Cryptography - Shiro Hero

Challenge:

For this crypto challenge we have to download four files: three Python scripts and one text file.

chall.py

from secrets import randbits
from prng import xorshiro256
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from ecc import ECDSA
from Crypto.Util.number import bytes_to_long, long_to_bytes
import hashlib

# Read the flag through a context manager so the file handle is closed.
with open("flag.txt", "rb") as flag_file:
    flag = flag_file.read()

# Seed the PRNG with four secret 64-bit words, then leak four raw outputs.
state = [randbits(64) for _ in range(4)]
prng = xorshiro256(state)
leaks = [prng.next_raw() for _ in range(4)]
print(f"PRNG leaks: {[hex(x) for x in leaks]}")

Apriv, Apub = ECDSA.gen_keypair()
print(f"public_key = {Apub}")
msg = b"My favorite number is 0x69. I'm a hero in your mother's bedroom, I'm a hero in your father's eyes. What am I?"
# NOTE: the message is converted to an integer directly; it is not hashed.
H = bytes_to_long(msg)
# The signing nonce is drawn from the same PRNG whose outputs were leaked above.
sig = ECDSA.ecdsa_sign(H, Apriv, prng)
print(f"Message = {msg}")
print(f"Hash = {H}")
print(f"r, s = {sig}")

# Encrypt the flag under SHA-256 of the private key; the IV is prepended as hex.
key = hashlib.sha256(long_to_bytes(Apriv)).digest()
iv = randbits(128).to_bytes(16, "big")
cipher = AES.new(key, AES.MODE_CBC, iv)
ciphertext = iv.hex() + cipher.encrypt(pad(flag, 16)).hex()
print(f"ciphertext = {ciphertext}")

# Persist the same values that were printed above.
with open("output.txt", "w", encoding="utf-8") as f:
    f.write(f"PRNG leaks: {[hex(x) for x in leaks]}\n")
    f.write(f"public_key = {Apub}\n")
    f.write(f"Message = {msg}\n")
    f.write(f"Hash = {H}\n")
    f.write(f"r, s = {sig}\n")
    f.write(f"ciphertext = {ciphertext}\n")

ecc.py

#!/usr/bin/env python3
import random
from hashlib import sha3_256, sha256
from Crypto.Util.number import bytes_to_long, inverse
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad, pad
from prng import xorshiro256, MASK64     
import hashlib
import os

class ECDSA:
    """ECDSA implementation for the secp256k1 curve y^2 = x^3 + a*x + b (mod p).

    Points are ``(x, y)`` tuples of ints and the point at infinity is
    represented as ``(None, None)``. Every method is static; the class is a
    namespace holding the curve parameters.
    """

    # parameters
    p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F  # field prime
    a = 0
    b = 7
    Gx = 55066263022277343669578718895168534326250603453777594175500187360389116729240
    Gy = 32670510020758816978083085130507043184471273380659243275938904335757337482424
    n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141  # order of G
    G = (Gx, Gy)

    @staticmethod
    def digest(msg: bytes) -> int:
        """Hash a message with SHA-256 and return the digest as an integer."""
        return bytes_to_long(sha256(msg).digest())

    @staticmethod
    def point_add(P, Q):
        """Return P + Q on the curve, treating (None, None) as the identity."""
        if P == (None, None):
            return Q
        if Q == (None, None):
            return P
        (x1, y1), (x2, y2) = P, Q
        # P + (-P) is the point at infinity; this also covers doubling a point with y == 0.
        if x1 == x2 and (y1 + y2) % ECDSA.p == 0:
            return (None, None)
        if P == Q:
            # Tangent slope (3x^2 + a) / 2y; the `a` term is zero on secp256k1.
            lam = (3 * x1 * x1 + ECDSA.a) * pow(2 * y1, -1, ECDSA.p) % ECDSA.p
        else:
            # Chord slope through P and Q.
            lam = (y2 - y1) * pow(x2 - x1, -1, ECDSA.p) % ECDSA.p
        x3 = (lam * lam - x1 - x2) % ECDSA.p
        y3 = (lam * (x1 - x3) - y1) % ECDSA.p
        return (x3, y3)

    @staticmethod
    def scalar_mult(k, P):
        """Return k*P using right-to-left double-and-add.

        Raises:
            ValueError: if k is negative (the shift loop would never terminate).
        """
        if k < 0:
            raise ValueError("scalar must be non-negative")
        R = (None, None)
        while k:
            if k & 1:
                R = ECDSA.point_add(R, P)
            P = ECDSA.point_add(P, P)
            k >>= 1
        return R

    @staticmethod
    def gen_keypair():
        """Return (d, Q) with d uniform in [1, n-1] and Q = d*G."""
        import secrets  # CSPRNG: a private key must not come from `random`

        d = secrets.randbelow(ECDSA.n - 1) + 1
        Q = ECDSA.scalar_mult(d, ECDSA.G)
        return d, Q

    @staticmethod
    def ecdsa_sign(h: int, d: int, prng: "xorshiro256"):
        """Sign the integer h with private key d and return (r, s).

        The nonce is ``prng() % n``. ``prng`` is called with no arguments and
        must return an int. The nonce is redrawn whenever k, r or s would be
        zero. Anyone who can predict ``prng`` learns k and hence
        d = (s*k - h) * r^-1 mod n.
        """
        while True:
            k = prng() % ECDSA.n
            if not k:
                continue
            x, _ = ECDSA.scalar_mult(k, ECDSA.G)
            if x is None:
                continue
            r = x % ECDSA.n
            if not r:
                continue
            s = (pow(k, -1, ECDSA.n) * (h + r * d)) % ECDSA.n
            if s:
                return r, s

    @staticmethod
    def ecdsa_verify(h, Q, sig):
        """Return True iff sig = (r, s) is a valid signature of h under public key Q."""
        r, s = sig
        if not (1 <= r < ECDSA.n and 1 <= s < ECDSA.n):
            return False
        w = pow(s, -1, ECDSA.n)
        u1 = (h * w) % ECDSA.n
        u2 = (r * w) % ECDSA.n
        x, _ = ECDSA.point_add(ECDSA.scalar_mult(u1, ECDSA.G), ECDSA.scalar_mult(u2, Q))
        if x is None:
            return False
        return (x % ECDSA.n) == r

prng.py

#!/usr/bin/python3
from Crypto.Util.number import bytes_to_long, inverse
MASK64 = (1 << 64) - 1  # keeps every state word within 64 bits


def _rotl(x: int, k: int) -> int:
    """Rotate the 64-bit word x left by k bits.

    x is masked to 64 bits first and k is reduced modulo 64, so the result is
    always a valid 64-bit word. Without the mask, stray high bits would leak
    into the low end through the right shift.
    """
    x &= MASK64
    k %= 64
    return ((x << k) | (x >> (64 - k))) & MASK64


class xorshiro256:
    """Generator with a 256-bit state held as four 64-bit words s0..s3.

    NOTE: next_raw() returns the *updated* s1 word untempered. It is built
    only from XORs, shifts and rotations, so every raw output is a linear
    function (over GF(2)) of the seed. This is what makes the leaked outputs
    invertible.
    """

    def __init__(self, seed):
        """Seed the generator from an iterable of four ints; extra high bits are masked off.

        Raises:
            ValueError: if seed does not contain exactly four words.
        """
        words = list(seed)
        if len(words) != 4:
            raise ValueError("seed must have four 64-bit words")
        self.s = [w & MASK64 for w in words]

    @staticmethod
    def _temper(s1: int) -> int:
        """Scramble a raw word: rotl(s1 * 5, 7) * 9 (mod 2**64)."""
        return (_rotl((s1 * 5) & MASK64, 7) * 9) & MASK64

    def next_raw(self) -> int:
        """Advance the state one step and return the new s1 (no tempering)."""
        s0, s1, s2, s3 = self.s
        t = (s1 << 17) & MASK64  # taken from the *old* s1

        s2 ^= s0
        s3 ^= s1
        s1 ^= s2  # new s1 = s1 ^ s2 ^ s0 (original words)
        s0 ^= s3
        s2 ^= t
        s3 = _rotl(s3, 45)

        self.s = [s0, s1, s2, s3]
        return s1

    def randrange(self, start, stop, inclusive=False):
        """Return an int in [start, stop), or in [start, stop] when inclusive is true.

        Uses a plain modulo reduction of next_raw(). The result is therefore
        slightly biased unless the range size is a power of two.

        Raises:
            ValueError: if the range is empty.
        """
        width = stop - start + (1 if inclusive else 0)
        if width <= 0:
            raise ValueError("empty range for randrange()")
        return start + self.next_raw() % width

    def __call__(self) -> int:
        """Return the next tempered output (one next_raw() step, then _temper)."""
        return self._temper(self.next_raw())

output.txt

PRNG leaks: ['0x785a1cb672480875', '0x91c1748fec1dd008', '0x5c52ec3a5931f942', '0xac4a414750cd93d7']
public_key = (108364470534029284279984867862312730656321584938782311710100671041229823956830, 13364418211739203431596186134046538294475878411857932896543303792197679964862)
Message = b"My favorite number is 0x69. I'm a hero in your mother's bedroom, I'm a hero in your father's eyes. What am I?"
Hash = 9529442011748664341738996529750340456157809966093480864347661556347262857832209689182090159309916943522134394915152900655982067042469766622239675961581701969877932734729317939525310618663767439074719450934795911313281256406574646718593855471365539861693353445695
r, s = (54809455810753652852551513610089439557885757561953942958061085530360106094036, 42603888460883531054964904523904896098962762092412438324944171394799397690539)
ciphertext = 404e9a7bbdac8d3912d881914ab2bdb924d85338fbd1a6d62a88d793b4b9438400489766e8e9fb157c961075ad4421fc

The main file is chall.py, since it contains the whole scheme we have to break. It chains several primitives (a PRNG, an ECDSA signature and AES-CBC encryption) that we need to break one after another. We also have ecc.py and prng.py, which show the implementation of the helper functions, and output.txt, which contains the values printed by the challenge.

Knowing all this, let's start.

Here four random 64-bit numbers are generated and used as the state of a PRNG called xorshiro256.

# Excerpt from chall.py: the PRNG is seeded with four secret 64-bit words.
state = [randbits(64) for _ in range(4)]
prng = xorshiro256(state)

In xorshiro256 these four numbers (the state) are named s0, s1, s2, s3. The “leaks” printed to output.txt are not the original s0…s3 values. Each leak is the new value of s1 returned by one of the four next_raw() calls.

# Excerpt from chall.py: four raw (untempered) next_raw() outputs are leaked.
leaks = [prng.next_raw() for _ in range(4)]
print(f"PRNG leaks: {[hex(x) for x in leaks]}")
def next_raw(self) -> int:
        """Excerpt of xorshiro256.next_raw from prng.py.

        Advances the four-word state using only XORs, a shift and a rotation,
        and returns the new s1. That returned value is what gets leaked.
        """
        s0, s1, s2, s3 = self.s
        t = (s1 << 17) & MASK64  # computed from the *old* s1

        s2 ^= s0
        s3 ^= s1
        s1 ^= s2  # new s1 = s1 ^ s2 ^ s0 (original words) -- the leaked value
        s0 ^= s3
        s2 ^= t
        s3  = _rotl(s3, 45)

        self.s = [s0, s1, s2, s3]
        return s1

As you can see above, the whole state (s0, s1, s2, s3) is updated, but we only see s1.

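Understanding this part is important because we need to recover the four original numbers in order to continue. To recap: the PRNG is advanced four times with next_raw(), and after each round we see the new value of s1. next_raw() only uses XORs, shifts and rotations, so every leak is a linear function (over GF(2)) of the 256-bit initial state. The four 64-bit leaks therefore give 256 equations in 256 unknown bits, which is enough to solve for the state as long as the equations are independent.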

To solve this I used the z3 SMT solver. I won't go into much detail on the script below. It models next_raw() symbolically and asks z3 for four 64-bit values whose first four outputs match the leaks.

from z3 import BitVec, BitVecVal, Or, RotateLeft, Solver, sat

MASK64 = (1 << 64) - 1


def rotl(x, k):
    """Symbolic 64-bit rotate-left (mirrors prng._rotl)."""
    return RotateLeft(x, k)


# The four raw outputs printed by chall.py (output.txt).
leaks = [
    0x785a1cb672480875,
    0x91c1748fec1dd008,
    0x5c52ec3a5931f942,
    0xac4a414750cd93d7,
]

# Unknown initial state words.
s0 = BitVec('s0', 64)
s1 = BitVec('s1', 64)
s2 = BitVec('s2', 64)
s3 = BitVec('s3', 64)


def next_raw_z3(s0, s1, s2, s3):
    """Symbolic copy of xorshiro256.next_raw.

    Returns the four updated state words followed by the leaked output
    (the updated s1).
    """
    t = (s1 << 17) & MASK64

    s2_new = s2 ^ s0
    s3_new = s3 ^ s1
    s1_new = s1 ^ s2_new
    s0_new = s0 ^ s3_new
    s2_new = s2_new ^ t
    s3_new = rotl(s3_new, 45)

    return s0_new, s1_new, s2_new, s3_new, s1_new


solver = Solver()
state = [s0, s1, s2, s3]

# Each round's output must equal the corresponding leak.
for leak in leaks:
    *state, out = next_raw_z3(*state)
    solver.add(out == BitVecVal(leak, 64))

if solver.check() != sat:
    raise SystemExit("No initial state reproduces the leaks")

model = solver.model()
# model_completion guarantees a concrete value even for an unconstrained word.
recovered_state = [model.eval(s, model_completion=True).as_long() for s in (s0, s1, s2, s3)]
print("Estado recuperado:", recovered_state)

# The update is linear over GF(2): the leaks fix the state only if the
# 256 equations are independent, so check that no other seed also fits.
solver.add(Or([word != value for word, value in zip((s0, s1, s2, s3), recovered_state)]))
if solver.check() == sat:
    print("Warning: the leaks do not determine the state uniquely")

Running the script recovers the original state: [4632343889369999961, 10793220881798324403, 12527397580889080479, 11809022490152434257]

Next comes the second primitive: an ECDSA signature.

# Excerpt from chall.py: key generation and signing with a PRNG-derived nonce.
Apriv, Apub = ECDSA.gen_keypair()
print(f"public_key = {Apub}")
msg = b"My favorite number is 0x69. I'm a hero in your mother's bedroom, I'm a hero in your father's eyes. What am I?"
H = bytes_to_long(msg)  # the message is used directly as an integer; it is not hashed
sig = ECDSA.ecdsa_sign(H, Apriv, prng)  # nonce comes from the PRNG whose outputs leaked

As we can see in ecc.py, the nonce k is computed with prng(). This call happens right after the four next_raw() calls that produced the leaks, so with the recovered state we can reproduce k.

    def ecdsa_sign(h: int, d: int, prng: xorshiro256):
        """Excerpt of ECDSA.ecdsa_sign from ecc.py.

        The nonce k is prng() % n. The nonce is redrawn if k, r or s would be
        zero, or if k*G is the point at infinity. Knowing k and (r, s) is
        enough to solve s = k^-1 (h + r*d) mod n for the private key d.
        """
        while True:
            k = prng() % ECDSA.n  # the only secret randomness in the signature
            if not k:
                continue
            x, _ = ECDSA.scalar_mult(k, ECDSA.G)
            if x is None:  # k*G is the point at infinity
                continue
            r = x % ECDSA.n
            if not r:
                continue
            s = (inverse(k, ECDSA.n) * (h + r * d)) % ECDSA.n
            if s:
                return r, s

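We have the original state, so we call next_raw() four times to skip the leaked outputs, then call prng() once. prng() performs a fifth next_raw() and tempers the result, exactly as ecdsa_sign does.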

Before doing this, let's look at the ECDSA signing equation used by ecdsa_sign.

With some algebra we can rearrange it to isolate the private key d. Every other value is available in output.txt or ecc.py, and k can be reproduced from the recovered PRNG state:

$$s \equiv k^{-1}\,(H + r\,d) \pmod{n} \quad\Longrightarrow\quad d \equiv (s\,k - H)\,r^{-1} \pmod{n}$$

from prng import xorshiro256

# Initial PRNG state recovered with z3.
RECOVERED_STATE = [4632343889369999961, 10793220881798324403, 12527397580889080479, 11809022490152434257]
# secp256k1 group order (ECDSA.n in ecc.py).
ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

gen = xorshiro256(RECOVERED_STATE)

# Skip the four outputs that chall.py printed as "PRNG leaks".
for _ in range(4):
    gen.next_raw()

# ecdsa_sign draws its nonce as prng() % n.
nonce = gen() % ORDER

# Values copied from output.txt.
H = 9529442011748664341738996529750340456157809966093480864347661556347262857832209689182090159309916943522134394915152900655982067042469766622239675961581701969877932734729317939525310618663767439074719450934795911313281256406574646718593855471365539861693353445695
R = 54809455810753652852551513610089439557885757561953942958061085530360106094036
S = 42603888460883531054964904523904896098962762092412438324944171394799397690539

# s = k^-1 (H + r*d)  =>  d = (s*k - H) * r^-1  (mod n)
inv_r = pow(R, -1, ORDER)
d = (S * nonce - H) * inv_r % ORDER

print(d)

d: 100589891343820979015464582911071111464252983749550820544942776016668758604656

The final step is AES-CBC. To get the key we apply the same hash as chall.py, SHA-256 of long_to_bytes(d). The IV is prepended to the ciphertext: ciphertext = IV (16 bytes, i.e. the first 32 hex characters) + encrypted flag. So we just need to split them.

# Excerpt from chall.py: the AES key is SHA-256 of the private key's bytes,
# and the random IV is prepended (as hex) to the ciphertext.
key = hashlib.sha256(long_to_bytes(Apriv)).digest()
iv = randbits(128).to_bytes(16, "big")
cipher = AES.new(key, AES.MODE_CBC, iv)
ciphertext = iv.hex() + cipher.encrypt(pad(flag, 16)).hex()

And here is the script:

import hashlib

from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from Crypto.Util.number import long_to_bytes

# Private key recovered in the previous step.
d = 100589891343820979015464582911071111464252983749550820544942776016668758604656

# chall.py outputs iv.hex() + ciphertext.hex(): a 16-byte IV followed by the AES-CBC body.
ciphertext_hex = "404e9a7bbdac8d3912d881914ab2bdb924d85338fbd1a6d62a88d793b4b9438400489766e8e9fb157c961075ad4421fc"

# Same key derivation as chall.py: SHA-256 of the private key's big-endian bytes.
key = hashlib.sha256(long_to_bytes(d)).digest()
ciphertext_bytes = bytes.fromhex(ciphertext_hex)
if len(ciphertext_bytes) < 32 or len(ciphertext_bytes) % 16:
    raise SystemExit("ciphertext must be a 16-byte IV followed by whole AES blocks")
iv = ciphertext_bytes[:16]
ciphertext_data = ciphertext_bytes[16:]

cipher = AES.new(key, AES.MODE_CBC, iv)
flag = unpad(cipher.decrypt(ciphertext_data), 16)

print("FLAG:", flag.decode())

FLAG: L3AK{u_4r3_th3_sh1r0_h3r0!}