CA CTF 2022: Breaking a custom hash function with z3 - Memory Acceleration

Breaking a custom hash function with z3, WizardAlfredo shares his write-up of Memory Acceleration from Cyber Apocalypse CTF 2022.

WizardAlfredo, Jun 29, 2022

In this write-up, we'll go over the solution for Memory Acceleration, a medium-hard crypto challenge that requires exploiting a custom hash function using z3 and a little brute forcing.

Description 📄

While everyone was asleep, you were pushing the capabilities of your technology to the max. Night after night, you frantically tried to repair the encrypted parts of your brain, reversing custom protocols implemented by your father, wanting to pinpoint exactly what damage had been done and constantly keeping notes because of your inability to form new memories. On one of those nights, you had a flashback. Your father had always talked about a new technology and how it would change the galaxy. You realized that he had used it on you. This technology dealt with a proof-of-work function and decentralized networks. With Virgil's help, you had a "Eureka!" moment, but his approach, brute forcing, meant draining all your energy. Can you find a quicker way to validate new memory blocks?

The application at-a-glance 🔍

When we connect to the TCP server with a tool like nc, some debug messages appear:

┏━[~]
┗━━ ■ nc 178.62.119.24 30554
Virgil says:
Klaus I'm connecting the serial debugger to your memory.
Please stay still. We don't want anything wrong to happen.
Ok you should be able to see debug messages now..

DEBUG MSG - You need to validate this memory block: You don't have to add the z3 solver to your firmware ever again. Now you can use it forever.
DEBUG MSG - Enter first key: 

The server asks us to enter two keys. If we try two random ones, the server rejects the proof of work and Virgil tells us that we have calculated something wrong:

DEBUG MSG - Enter first key: 1
DEBUG MSG - Enter second key: 2
DEBUG MSG - Incorect proof of work

Virgil says: 
You calculated something wrong Klaus we need to start over.

At this point, we need to start looking at the source code to understand how things work.

Analysing the source code 📖

Two files are provided: source.py and pofwork.py.

source.py

Looking at source.py, we can see that our goal is to make the phash function output 0 four times in a row, once per memory block.

The basic workflow of the script is as follows:

  1. A memory block is loaded and waits to be validated.
  2. Two keys are requested from the debugging interface.
  3. phash is computed. If the resulting proof_of_work is valid, the script continues; otherwise, it calls exit().
  4. The new memory is appended to the previous block, and the process repeats.

Steps 1, 2, and 4 are not that interesting for analysis. We will go directly to the phash function.

    proof_of_work = phash(block, first_key, second_key)

phash is imported from another file, so let's take a look at it.

from pofwork import phash

pofwork.py

The file defines an sbox and three functions: sub, rotl, and phash.

The sub function is a simple byte-wise substitution using the standard AES S-box (the one listed on Wikipedia), so there is nothing interesting there.
The rotl function is a standard 32-bit left rotation. Again, there is nothing interesting here.

Let's now look at the juicy phash function.

First, the memory block passed as input is transformed: it is hashed with MD5, the 16-byte digest is repeated four times (64 bytes), and the result is split into sixteen 4-byte big-endian integers.

    block = md5(block.encode()).digest()
    block = 4 * block
    blocks = [int.from_bytes(block[i:i+4],'big') for i in range(0, len(block), 4)]
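To make the expansion concrete, here is a small standalone sketch that repeats those three lines on an arbitrary example string (the input text is a placeholder, not a real memory block). It shows that the 64 expanded bytes yield 16 words in which the same four words repeat, and that every word fits in 32 bits. Since the round loop runs for i in range(13) and reads blocks[i+3], its largest index is 15, so all 16 words are used without going out of range.

```python
from hashlib import md5


def expand_block(block: str) -> list[int]:
    """Mirror the first lines of phash: MD5 the block, repeat the digest 4 times,
    and split the 64 bytes into sixteen big-endian 32-bit words."""
    digest = md5(block.encode()).digest()   # 16 bytes
    data = 4 * digest                       # 64 bytes
    return [int.from_bytes(data[i:i + 4], 'big') for i in range(0, len(data), 4)]


words = expand_block("example memory block")
print(len(words))                # 16
print(words[:4] == words[4:8])   # True: the same four words repeat
```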

After that, some state variables are initialized. Note that key1 is used here as the initial value of x:

    m = 0xffffffff
    rv1, rv2 = 0x2423380b4d045, 0x3b30fa7ccaa83
    x, y, z, u = key1, 0x39ef52e9f30b3, 0x253ea615d0215, 0x2cd1372d21d77

Then come 13 rounds of chaotic mixing (multiplications, XORs, rotations, and S-box substitutions), after which h is derived from rv1 and reduced to 32 bits.

    for i in range(13):
        x, y, z, u = blocks[i] ^ x, blocks[i+1] ^ y, blocks[i+2] ^ z, blocks[i+3] ^ u
        rv1 ^= (x := (x & m) * (m + (y >> 16)) ^ rotl(z, 3))
        rv2 ^= (y := (y & m) * (m + (z >> 16)) ^ rotl(x, 3))
        rv1, rv2 = rv2, rv1
        rv1 = sub(rv1)
        rv1 = bytes_to_long(rv1)

    h = rv1 + 0x6276137d7 & m
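Two Python details make these lines easy to misread: + binds more tightly than &, and * binds more tightly than ^. In addition, the walrus operator in rv1 ^= (x := ...) both stores the new x and XORs it into rv1. The short check below confirms those readings; x, y, and z are arbitrary placeholder values, not values taken from the challenge.

```python
m = 0xffffffff
C = 0x6276137d7
rv1 = 0x2423380b4d045  # the initial rv1 constant, used here only as a sample value

# '+' binds tighter than '&', so this is (rv1 + C) reduced to 32 bits
h = rv1 + C & m
assert h == (rv1 + C) & m and h < 2**32


def rotl(n, b):
    return ((n << b) | (n >> (32 - b))) & 0xffffffff


# '*' binds tighter than '^', so a round computes (product) XOR rotl(z, 3)
x, y, z = 0x1234, 0x56789abc, 0xdeadbeef
expr = (x & m) * (m + (y >> 16)) ^ rotl(z, 3)
assert expr == ((x & m) * (m + (y >> 16))) ^ rotl(z, 3)

# The walrus operator assigns the new x and XORs it into rv1 in one statement
rv1_before = rv1
rv1 ^= (x := expr)
assert x == expr and rv1 == rv1_before ^ expr
```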

Finally, key2 is passed through sub, and the resulting bytes are mixed into h with a few shifts, additions, and XORs:

    key2 = sub(key2)

    for i, d in enumerate(key2):
        a = (h << 1) & m
        b = (h << 3) & m
        c = (h >> 4) & m
        h ^= (a + b + c - d)
        h += h
        h &= m

    h *= u * z
    h &= m

    return h

A short summary of what we have found so far:

  1. key1 is only used in the first half, where it influences the intermediate value h.
  2. h and the substituted key2 bytes then go through a much simpler loop of shifts, additions, and XORs that produces the final hash.

Searching for the bugs 👾

There are many things wrong with this hash function, but our only goal is to make it output 0. If we break it down into simpler problems, we can see that its second half is simple enough for z3 to solve.

Exploitation 🔓

Connecting to the server

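A basic pwntools skeleton for connecting to the server (pwn() is the main routine of the exploit):

from pwn import remote

if __name__ == '__main__':
    # r is a global connection shared by all the helper functions
    r = remote('0.0.0.0', 1337)
    pwn()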

Getting the memory block

When we connect, the server generates a memory block and sends it in a debug message. To extract it, we strip the message prefix:

def getBlockToValidate():
    debug_msg = r.recvline()
    block = debug_msg.decode().strip()[len('DEBUG MSG - You need to validate this memory block: '):]
    return block
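A slightly more defensive variant of the same parsing, written as a pure function so it can be tried offline on the debug line from the session shown earlier. The prefix check is our own addition, not part of the original script: it fails loudly if the wrong line was read (for example, if the number of skipped messages is off) instead of silently returning a mangled block.

```python
PREFIX = 'DEBUG MSG - You need to validate this memory block: '


def parse_block(debug_msg: bytes) -> str:
    """Extract the memory block from one raw debug line, as returned by r.recvline()."""
    line = debug_msg.decode().strip()
    if not line.startswith(PREFIX):
        raise ValueError(f'unexpected line: {line!r}')
    return line[len(PREFIX):]


sample = (b"DEBUG MSG - You need to validate this memory block: You don't have to "
          b"add the z3 solver to your firmware ever again. Now you can use it forever.\n")
print(parse_block(sample))
```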

z3 magic

As indicated above, we need to split the hash function into sections so that z3 can solve the system.
More precisely, we can ignore the entire first half of the hash function and focus only on the second half:

    h = rv1 + 0x6276137d7 & m
    key2 = sub(key2)

    for i, d in enumerate(key2):
        a = (h << 1) & m
        b = (h << 3) & m
        c = (h >> 4) & m
        h ^= (a + b + c - d)
        h += h
        h &= m

    h *= u * z
    h &= m

How can we further simplify the problem for z3?

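We know that h is deterministic for a given (block, key1) pair, so we can simply compute it in Python.
We can also see that sub(key2) is invertible, since it is a byte-wise S-box lookup: we can solve for the substituted bytes and map them back with the inverse S-box. Finally, the last two lines are irrelevant: if h is already 0 after the loop, multiplying it by u * z keeps it at 0.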
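Both observations are easy to check in isolation. The sketch below re-declares the AES S-box, builds an inverse table (equivalent to the write-up's isub, which uses sbox.index), and confirms that sending the integer form of isub(s) makes sub() reproduce s. It also shows an edge case worth knowing: the preimage of 0x63 is 0x00, so if the first solved byte is 0x63, that leading zero disappears when the key travels as an integer and the loop would see one byte fewer. The int.to_bytes-based long_to_bytes is a stand-in for the pycryptodome helper, and the example byte values are arbitrary.

```python
# AES S-box, identical to the sbox table in pofwork.py
SBOX = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
]

# Precomputed inverse table: same result as SBOX.index(v), but O(1) per lookup
INV_SBOX = [0] * 256
for i, v in enumerate(SBOX):
    INV_SBOX[v] = i


def long_to_bytes(n: int) -> bytes:
    # Stand-in for Crypto.Util.number.long_to_bytes: minimal big-endian encoding
    return n.to_bytes(max(1, (n.bit_length() + 7) // 8), 'big')


def sub(n: int) -> bytes:
    return bytes(SBOX[b] for b in long_to_bytes(n))


def isub(s: bytes) -> bytes:
    return bytes(INV_SBOX[b] for b in s)


# Suppose the solver says the loop needs these substituted bytes (arbitrary example)
wanted = bytes([0x10, 0x20, 0x30, 0x40, 0x50, 0x60])
key2 = int.from_bytes(isub(wanted), 'big')  # the integer we would send
assert sub(key2) == wanted                  # the server's sub() recovers the same bytes

# Edge case: a leading 0x63 maps to 0x00, which is lost in the integer encoding
edge = bytes([0x63, 0x01])
assert sub(int.from_bytes(isub(edge), 'big')) == bytes([0x01])

# Once h is 0, the final "h *= u * z; h &= m" cannot change it (u, z are sample values)
h, u, z = 0, 0x2cd1372d21d77, 0x253ea615d0215
assert (h * u * z) & 0xffffffff == 0
```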

So the problem is:

If we control the key bytes and h is a known value, how can we make h equal 0 after the following bitwise operations?

    for i, d in enumerate(key2):
        a = (h << 1) & m
        b = (h << 3) & m
        c = (h >> 4) & m
        h ^= (a + b + c - d)
        h += h
        h &= m
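Why is it safe to model this loop with 32-bit z3 bit-vectors, when in Python a + b + c - d can go negative or exceed 2**32? Only the low 32 bits ever feed back: XOR works bit by bit, and doubling followed by & m only depends on the low 31 bits. The sketch below checks this on random inputs by comparing the loop as written with a version that wraps every intermediate value mod 2**32. The six-byte key length mirrors the six variables c0 to c5 of the z3 model; the seed is arbitrary.

```python
import random

M = 0xffffffff


def second_half_python(h: int, key_bytes: bytes) -> int:
    """The loop exactly as in pofwork.py (intermediates may be negative or wider than 32 bits)."""
    for d in key_bytes:
        a = (h << 1) & M
        b = (h << 3) & M
        c = (h >> 4) & M
        h ^= (a + b + c - d)
        h += h
        h &= M
    return h


def second_half_u32(h: int, key_bytes: bytes) -> int:
    """Same loop with every intermediate reduced mod 2**32, like 32-bit registers or BitVecs."""
    for d in key_bytes:
        t = ((h << 1) + (h << 3) + (h >> 4) - d) & M  # wraps like a 32-bit value
        h = ((h ^ t) << 1) & M                         # "h += h" is a left shift by one
    return h


rng = random.Random(2022)
for _ in range(10_000):
    h0 = rng.getrandbits(32)
    key = bytes(rng.getrandbits(8) for _ in range(6))
    assert second_half_python(h0, key) == second_half_u32(h0, key)
```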

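z3 to the rescue. Let's build a model:

  1. Create a solver.
  2. Declare the bit-vectors: six 32-bit variables c0 to c5 for the substituted key2 bytes, plus one for the starting value of h.
  3. Add the constraints: h starts at the known value, each key variable is restricted to byte values (0 to 254 in the code), the loop is applied symbolically, and the final h equals 0.
  4. Check satisfiability and, if a model exists, read the key bytes out of it.
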
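def findSecondKeyWithZ3(block, key1):
    s = Solver()

    # One 32-bit bit-vector per substituted key2 byte, plus the starting value of h
    xs = list(BitVecs('c0 c1 c2 c3 c4 c5', 32))
    h = BitVec('hs', 32)

    # h is fully determined by block and key1, so compute it concretely
    target_h = phashFirstHalf(block, key1)
    s.add(h == target_h)

    for i, e in enumerate(xs):
        s.add(e >= 0)
        s.add(e < 255)
        # BitVec arithmetic wraps mod 2**32, so no explicit "& m" is needed;
        # rotr() turns z3's arithmetic >> into the logical shift Python performs
        h ^= ((h << 1) + (h << 3) + rotr(h, 4) - e)
        h += h

    # The final multiplication cannot change a zero, so require h == 0 here
    s.add(h == 0)

    # Compare with sat explicitly: an "unknown" result has no model to read
    if s.check() == sat:
        m = s.model()
        s = bytes([m[c].as_long() for c in xs])

        # Sanity check against the plain-Python version of the loop
        assert phashSecondHalf(s, target_h) == 0
        return(s)
    else:
        return('unsat')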
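The rotr helper deserves a closer look. Despite its name, it is not a rotation: it is a logical right shift. On a z3 BitVec, the >> operator is an arithmetic (sign-extending) shift, so masking off the top n bits is what makes the symbolic loop match Python's (h >> 4) on a non-negative int; z3's LShR(h, 4) expresses the same thing directly. The sketch below emulates the three behaviours on plain Python ints. The sample value is arbitrary, and sar32 is a hypothetical helper of ours that only imitates what z3 does.

```python
M = 0xffffffff


def rotr(v, n):
    # The write-up's helper: a logical right shift on a 32-bit value
    return (v >> n) & ((1 << (32 - n)) - 1)


def true_ror(v, n):
    # An actual 32-bit rotate-right, for comparison
    return ((v >> n) | (v << (32 - n))) & M


def sar32(v, n):
    # Emulates an arithmetic shift of a 32-bit value (what >> does on a z3 BitVec)
    signed = v - (1 << 32) if v & 0x80000000 else v
    return (signed >> n) & M


v = 0x80000015
print(hex(rotr(v, 4)))      # 0x8000001
print(hex(true_ror(v, 4)))  # 0x58000001
print(hex(sar32(v, 4)))     # 0xf8000001
# Masking the arithmetic shift recovers the logical one, which is what rotr does in z3
assert sar32(v, 4) & ((1 << 28) - 1) == rotr(v, 4)
```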

Brute forcing key1

You may wonder why findSecondKeyWithZ3() takes key1 as an input. Not every key1 produces an h for which the system is solvable, so we brute force key1 a bit: starting from 0, we increment it until z3 finds a solution.

def findKeys(block):
    first_key = 0
    while 1:
        second_key = findSecondKeyWithZ3(block, first_key)
        if (second_key != 'unsat'):
            return(first_key, bytes_to_long(isub(second_key)))
        first_key += 1
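The same search can also be written with None instead of the 'unsat' sentinel string and with an upper bound, so that a run which never finds a solvable key1 fails loudly instead of looping forever. In this sketch the solver is passed in as a parameter: solve_second_key is a hypothetical name standing in for findSecondKeyWithZ3 adapted to return None, and max_tries is an arbitrary bound. The caller would still apply isub and bytes_to_long to the returned bytes, as findKeys does.

```python
from typing import Callable, Optional


def find_keys(block: str,
              solve_second_key: Callable[[str, int], Optional[bytes]],
              max_tries: int = 10_000) -> tuple[int, bytes]:
    """Try key1 = 0, 1, 2, ... until the solver returns key2 bytes, or give up."""
    for first_key in range(max_tries):
        second_key = solve_second_key(block, first_key)
        if second_key is not None:  # None plays the role of the 'unsat' string
            return first_key, second_key
    raise RuntimeError(f'no key1 below {max_tries} gave a solvable system')


# Usage with a stand-in solver that only "succeeds" for key1 == 3
print(find_keys('some block', lambda block, k1: b'\x01\x02' if k1 == 3 else None))
```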

Sending the keys

def sendKeys(first_key, second_key):
    r.sendlineafter(b'DEBUG MSG - Enter first key: ', bytes(str(first_key), 'Latin'))
    r.sendlineafter(b'DEBUG MSG - Enter second key: ', bytes(str(second_key), 'Latin'))

Getting the flag

A final summary of all that was said above:

  1. We have connected to the server.
  2. We have obtained the memory block.
  3. We found that part of our hash function can be solved with z3.
  4. We brute forced key1 so that the z3 system is solvable.
  5. We ran our solver and found key2.
  6. Finally, we send the keys and get the next block.
  7. Repeat this process 4 times and we get the flag :)

This recap translates directly into the pwn() function:

def pwn():
    for _ in range(4):
        skipMessages()
        block = getBlockToValidate()
        print(f"Fetched block of memory that needs validation")
        first_key, second_key = findKeys(block)
        print(f"Found valid keys: {first_key}, {second_key} for block: {block}")
        sendKeys(first_key, second_key)
        print("Moving on to the next block")
    r.interactive()

The final script is:

import random
from hashlib import md5
from Crypto.Util.number import *
from z3 import *
from pwn import *


sbox = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
]


def rotl(n, b):
    return ((n << b) | (n >> (32 - b))) & 0xffffffff


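def rotr(v, n):
    # Despite its name, this is a logical right shift: z3's >> is arithmetic on
    # BitVecs, so the mask clears the sign-extended high bits
    return (v >> n) & ((1 << (32 - n)) - 1)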


def sub(b):
    b = long_to_bytes(b)
    return bytes([sbox[i] for i in b])


def isub(b):
    return bytes([sbox.index(i) for i in b])


def phashFirstHalf(text, key1):
    text = md5(text.encode()).digest()
    text = 4 * text
    text = [int.from_bytes(text[i:i+4], 'big') for i in range(0, len(text), 4)]

    m = 0xffffffff
    rv1, rv2 = 0x2423380b4d045, 0x3b30fa7ccaa83
    x, y, z, u = key1, 0x39ef52e9f30b3, 0x253ea615d0215, 0x2cd1372d21d77

    for i in range(13):
        x, y, z, u = text[i] ^ x, text[i+1] ^ y, text[i+2] ^ z, text[i+3] ^ u
        rv1 ^= (x := (x & m) * (m + (y >> 16)) ^ rotl(z, 3))
        rv2 ^= (y := (y & m) * (m + (z >> 16)) ^ rotl(x, 3))
        rv1, rv2 = rv2, rv1
        rv1 = sub(rv1)
        rv1 = bytes_to_long(rv1)

    h = rv1 + 0x6276137d7 & m
    return h


def phashSecondHalf(s, target_h):
    h = target_h
    for i, d in enumerate(s):
        a = (h << 1) & 0xffffffff
        b = (h << 3) & 0xffffffff
        c = (h >> 4) & 0xffffffff
        h ^= (a + b + c - d)
        h += h
        h &= 0xffffffff
    return h


def skipMessages():
    for _ in range(5):
        r.recvline()


def getBlockToValidate():
    debug_msg = r.recvline()
    block = debug_msg.decode().strip()[len(
        'DEBUG MSG - You need to validate this memory block: '):]
    return block


def findSecondKeyWithZ3(block, key1):
    s = Solver()

    target_h = phashFirstHalf(block, key1)

    xs = list(BitVecs('c0 c1 c2 c3 c4 c5', 32))
    h = BitVec('hs', 32)

    s.add(h == target_h)

    for i, e in enumerate(xs):
        s.add(e >= 0)
        s.add(e < 255)
        h ^= ((h << 1) + (h << 3) + rotr(h, 4) - e)
        h += h

    s.add(h == 0)

    if s.check() == sat:
        m = s.model()
        s = bytes([m[c].as_long() for c in xs])

        assert phashSecondHalf(s, target_h) == 0
        return(s)
    else:
        return('unsat')


def findKeys(block):
    first_key = 0
    while 1:
        second_key = findSecondKeyWithZ3(block, first_key)
        if (second_key != 'unsat'):
            return(first_key, bytes_to_long(isub(second_key)))
        first_key += 1


def sendKeys(first_key, second_key):
    r.sendlineafter(b'DEBUG MSG - Enter first key: ',
                    bytes(str(first_key), 'Latin'))
    r.sendlineafter(b'DEBUG MSG - Enter second key: ',
                    bytes(str(second_key), 'Latin'))


def pwn():
    for _ in range(4):
        skipMessages()
        block = getBlockToValidate()
        print(f"Fetched block of memory that needs validation")
        first_key, second_key = findKeys(block)
        print(
            f"Found valid keys: {first_key}, {second_key} for block: {block}")
        sendKeys(first_key, second_key)
        print("Moving on to the next block")
    r.interactive()


if __name__ == '__main__':
    r = remote('localhost', 1337)
    pwn()
 

If you want to dive deeper into z3, deuterium created a really cool and detailed write-up that can be found here.

And that’s a wrap for this challenge write-up!
