CSCG 2023 - Traps

Last modification on

Challenge

A flag checker. Cheap entertainment.

Challenge Files

Overview

We are provided with a single binary, a flag-checker called traps. The objective is to understand how flags are checked for validity, and generate a single input (the real flag) that satisfies all conditions.

Analysis

We begin reverse-engineering and immediately notice that the program uses ptrace to prevent other processes from debugging it. Additionally, the first function called in main (fcn.004025b0) appears to do nothing but compute a value without using the result and output "NO :(" before exiting. Not very reversing-friendly.

On closer inspection, however, we can decipher what is happening and that the actual input validation is implemented elsewhere.

The function fcn.004025b0 is called in main with the process pid as its argument:

0x00401581      e80a1f0300     call fcn.__getpid
0x00401586      4863d8         movsxd rbx, eax
0x00401589      89df           mov edi, ebx
0x0040158b      e820100000     call fcn.004025b0
0x00401590      e8ab120300     call fcn.fork

Inside fcn.004025b0 we find that a child process is forked, which begins ptracing the parent and modifies its registers.

0x004025c2      char input[0x40]
0x004025c2      if (__fork() != 0)
0x00402681          wait(nullptr)
0x004026ac          puts("Welcome to your average flag che…")
0x004026c0          _IO_fgets(&input, 0x40, stdout)
0x004026c8          no_and_exit(&input)
0x004026c8          noreturn
0x004025db      if (ptrace(PTRACE_ATTACH, arg1, nullptr, nullptr) s>= 0)
0x004025ed          int32_t var_fc
0x004025ed          int32_t rax_2 = waitpid(arg1, &var_fc, 0x40000000)
0x004025f2          int16_t rdx_1 = var_fc.w
0x00402611          if (rax_2 s< 0 || (rax_2 s>= 0 && (rdx_1.b & 0x7f) + 1 s> 1) || (rax_2 s>= 0 && (rdx_1.b & 0x7f) + 1 s<= 1 && (rdx_1.b & 0x7f) == 0))
0x004026d0              exit(status: zx.d(rdx_1:1.b))
0x004026d0              noreturn
0x00402611          if (rax_2 s>= 0 && (rdx_1.b & 0x7f) + 1 s<= 1 && (rdx_1.b & 0x7f) != 0 && ptrace(PTRACE_GETREGS, arg1, nullptr, &input) s>= 0)
0x0040263d              input[0].q = 0
0x00402653              if (ptrace(PTRACE_SETREGS, arg1, nullptr, &input) s>= 0 && ptrace(PTRACE_DETACH, arg1, nullptr, nullptr) s>= 0)
0x0040266e                  exit_2(0)
0x0040266e                  noreturn
0x0040269b      exit(status: 0xffffffff)
0x0040269b      noreturn

It reads the register contents into the same memory region occupied by the user input in the child process (causing the input[0] lifting artifact), and sets the first quad-word member inside the register struct to 0. This corresponds to the register r15 of the parent process. After this, the child detaches and exits.

Subsequently, the parent who was waiting at 0x00402681 is awoken by the child exit. Here, the lifter hides the control flow allowing this function to return normally, since the branch comparing r15 after wait seems like it should only ever resolve one way.

0x00402681      e8da000300     call fcn.wait
0x00402686      4c89f8         mov rax, r15
0x00402689      4885c0         test rax, rax
0x0040268c      7512           jne 0x4026a0
0x0040268e      4881c4f80000.  add rsp, 0xf8
0x00402695      5b             pop rbx
0x00402696      5d             pop rbp
0x00402697      c3             ret

This means that during normal execution, the calculations and greeter prints in fcn.004025b0 aren't even called!

Considering the rest of the main function, we find more ptrace shenanigans.

main
0x0040158b      bait(sx.d(__getpid()))
0x00401590      pid_t child = __fork()
0x00401597      int64_t ret
0x00401597      if (child == 0)
0x004015a3          ret = ptrace(PTRACE_TRACEME, 0, nullptr, nullptr)
0x004015ab          if (ret s>= 0)
0x004015b6              raise(0x13)
0x004015c0              child = child_run()
0x004015ab      if ((child == 0 && ret s>= 0) || child != 0)
0x004015d6          int32_t child_status
0x004015d6          int32_t* child_status_1 = &child_status
0x004015db          int32_t ret_1 = waitpid(child, &child_status, 0x40000000)
0x004015e0          int16_t child_rc = child_status.w
0x004015e6          if (ret_1 s< 0)
0x004017de              exit_label8:
0x004017de              exit(status: zx.d(child_rc:1.b))
0x004017de              noreturn
0x004015f6          if ((child_rc.b & 0x7f) + 1 s> 1)
0x004015f6              goto exit_label8
0x004015ff          if ((child_rc.b & 0x7f) == 0)
0x004015ff              goto exit_label8
0x0040161d          if (ptrace(PTRACE_SETOPTIONS, child, nullptr, 0x100000) s>= 0 && ptrace(PTRACE_CONT, child, nullptr, 0x12) s>= 0)
0x00401648              uint8_t func_len = 0x7f
0x0040164f              uint8_t* len = &chunk_lens
0x00401656              uint64_t* key = &chunk_keys
0x0040165d              int64_t pos = 0
0x0040167c              int32_t ret_2
0x0040167c              while (true)
0x0040167c                  ret_2 = waitpid(child, child_status_1, 0x40000000)
0x00401681                  child_rc = child_status.w
0x00401687                  if (ret_2 s< 0)
0x00401687                      break
0x00401697                  if ((child_rc.b & 0x7f) + 1 s> 1)
0x00401697                      break
0x004016a0                  if ((child_rc.b & 0x7f) == 0)
0x004016a0                      break
0x004016c1                  uint64_t regs[0x1d]
0x004016c1                  if (ptrace(PTRACE_GETREGS, child, nullptr, &regs) s< 0)
0x004016c1                      break
0x004016c7                  uint64_t func_len_1 = zx.q(func_len)
0x004016ea                  void* r12_2 = regs[0x10]
0x004016f2                  regs[0xd].o = func_len_1.q | *key << 0x40
0x00401702                  if (ptrace(PTRACE_SETREGS, child, nullptr, &regs) s< 0)
0x00401702                      break
0x00401715                  void* rbx_1 = r12_2 + 0x86
0x0040171d                  uint8_t rax_17 = func_len u>> 3
0x00401720                  int64_t i = 0
0x0040172e                  if (rax_17 != 0)
0x0040175a                      int64_t rax_19
0x0040175a                      do
0x00401741                          rax_19 = ptrace(PTRACE_POKETEXT, child, rbx_1, *(&chunk_codes + pos + (i << 3)))
0x00401749                          if (rax_19 s< 0)
0x00401749                              break
0x0040174f                          i = i + 1
0x00401753                          rbx_1 = rbx_1 + 8
0x00401753                      while (i != zx.q(rax_17))
0x00401749                      if (rax_19 s< 0)
0x00401763                          break
0x00401763                  pos = pos + func_len_1
0x0040177d                  if (ptrace(PTRACE_CONT, child, nullptr, 0x12) s< 0)
0x0040177d                      break
0x0040177f                  func_len = *len
0x00401783                  key = &key[1]
0x00401789                  len = &len[1]
0x0040178f                  if (func_len == 0)
0x004017a1                      int32_t ret_3 = waitpid(child, child_status_1, 0x40000000)
0x004017a6                      child_rc = child_status.w
0x004017bd                      if (ret_3 s>= 0 && (child_rc.b & 0x7f) + 1 s<= 1 && (child_rc.b & 0x7f) != 0)
0x004017d2                          return 0
0x004017aa                      break
0x00401687              if (ret_2 s< 0)
0x004017d6                  goto exit_label8
0x004017d6      exit(status: 0xffffffff)
0x004017d6      noreturn

First, another child is spawned that enables ptracing of its process through PTRACE_TRACEME and waits for a tracer to attach through raise(SIGSTOP).

The parent waits for the child to stop, then makes a PTRACE_SETOPTIONS request with PTRACE_O_EXITKILL to ensure that the child is killed if the parent exits, and resumes child execution through PTRACE_CONT. The parent then enters a loop in which values are read from arrays in memory and used to modify the rsi and rdi registers (regs[0xd] and regs[0xe] of user_regs_struct, which receive the chunk length and the chunk key respectively), as well as the program memory at rip+0x86 in the child process. The child is awoken in each iteration through PTRACE_CONT, presumably to execute the new instructions copied into its memory.

Let's see what child_run does:

0x0040251b      void* rax = mmap(nullptr, 0x1000, 7, 0x22, 0, 0)
0x00402524      if (rax == -1)
0x0040252b          exit(status: 0xffffffff)
0x0040252b          noreturn
0x00402537      int64_t rdx = 0
0x0040254b      do
0x0040253d          *(rax + rdx) = *(rdx + 0x40246f)
0x00402540          rdx = rdx + 1
0x00402540      while (rdx != 0x87)
0x0040254d      jump(rax)

And at 0x0040246f..

0x0040246f      breakpoint
0x00402476      int64_t i = 0x100
0x00402484      uint8_t buf[0x101]
0x00402484      do
0x00402480          buf[i] = i.b
0x00402484          i = i - 1
0x00402484      while (i != 0)
0x00402489      int64_t i_2 = (i + 1) << 8
0x0040248d      uint8_t val = 0
0x004024bb      int64_t _i_2
0x004024bb      do
0x00402497          val = val + buf[i_2]
0x00402499          _i_2 = i_2
0x0040249c          i_2.b = i_2.b & 7
0x0040249f          i_2.b = i_2.b << 3
0x004024a8          val = val + (key u>> i_2.b).b
0x004024aa          int64_t __i_2 = _i_2
0x004024ad          _i_2.b = buf[__i_2]
0x004024b2          uint8_t tmp = buf[1 + val]
0x004024b2          buf[1 + val] = _i_2.b
0x004024b2          _i_2.b = tmp
0x004024b6          buf[__i_2] = _i_2.b
0x004024bb          i_2 = __i_2 - 1
0x004024bb      while (i_2 != 0)
0x004024bd      val = 0
0x004024ed      do
0x004024c9          i_2 = i_2 + 1
0x004024cc          int64_t xorb
0x004024cc          xorb.b = i_2.b
0x004024ce          val = val + buf[1 + xorb]
0x004024d1          _i_2.b = buf[1 + xorb]
0x004024d5          uint64_t tmp_2 = buf[1 + val]
0x004024d5          buf[1 + val] = _i_2.b
0x004024d5          _i_2.b = tmp_2
0x004024d9          buf[1 + xorb] = _i_2.b
0x004024dd          xorb.b = buf[1 + val]
0x004024e0          xorb.b = xorb.b + _i_2.b
0x004024e3          xorb.b = buf[1 + xorb]
0x004024e6          *(i_2 + 0x4024f5) = *(i_2 + 0x4024f5) ^ xorb.b
0x004024e6      while (len != i_2)
0x004024f6      return val

First, 0x87 bytes of instructions from 0x0040246f are copied into an anonymous read-write-execute mapping obtained via mmap and executed there. An initial breakpoint causes the child to halt execution until the parent resumes it with PTRACE_CONT. While the child is halted at the breakpoint (0x0040246f in terms of the original addresses), the parent writes to rip+0x86 (0x00402470+0x86 = 0x4024f6), inserting instructions and overwriting the final return. The RC4-like encryption algorithm after the breakpoint then decrypts the inserted code in-place and subsequently executes it.

To find out what is being executed, we can replicate the decryption on the chunk array and analyze each chunk's functionality. I opted to reimplement the decryption in Python, however, instrumenting this code blob using the Unicorn Engine is also viable.

extract.py
import struct
import os
import os.path

# blobs.hex is expected to hold three hex-encoded lines, in this order:
# the 8-byte chunk keys, the 1-byte chunk lengths and the encrypted chunk
# code (presumably dumped from the binary's chunk_keys / chunk_lens /
# chunk_codes arrays -- confirm against your own dump).
with open("./blobs.hex") as blob_file:
    lines = blob_file.read().split("\n")
chunk_keys = bytes.fromhex(lines[0])
chunk_lens = bytes.fromhex(lines[1])
chunk_codes = bytes.fromhex(lines[2])

def rc4_decrypt(key, chunk, version=1):
    """Decrypt one chunk with the loader's RC4-like stream cipher.

    Args:
        key: Integer key; byte ``(i % 8)`` of it is mixed in at schedule
            step ``i``.
        chunk: Iterable of byte values (e.g. ``bytes``) to decrypt.
        version: ``1`` starts the keystream accumulator at 0; any other
            value starts it at 0x90.

    Returns:
        The decrypted data as ``bytes``, same length as ``chunk``.
    """
    # State table addressed through slots 1..0x100; slot 0 is never touched.
    state = [n & 0xff for n in range(0x101)]

    # Key scheduling: walk the table from the top slot down to slot 1.
    acc = 0
    slot = 0x100
    while slot > 0:
        key_byte = (key >> ((slot % 8) * 8)) & 0xff
        acc = (acc + state[slot] + key_byte) & 0xff
        state[slot], state[1 + acc] = state[1 + acc], state[slot]
        slot -= 1

    # The keystream accumulator restarts from a version-dependent seed.
    acc = 0 if version == 1 else 0x90

    cipher = list(chunk)
    plain = []
    for counter, byte in enumerate(cipher, start=1):
        src = 1 + (counter & 0xff)
        acc = (acc + state[src]) & 0xff
        dst = 1 + acc
        state[src], state[dst] = state[dst], state[src]
        stream_byte = state[1 + ((state[dst] + state[src]) & 0xff)]
        plain.append(byte ^ stream_byte)

    return bytes(plain)

# Decrypted chunks are written to dec/; create it up front so the first
# write does not fail on a fresh checkout.
os.makedirs("dec", exist_ok=True)

pos = 0
# The length of the first chunk (0x7f) is hard-coded in main; the lengths
# of all following chunks come from the length table.
chunk_lens = [0x7f] + list(chunk_lens)
for i, chunk_len in enumerate(chunk_lens):
    # A zero length terminates the table.
    if chunk_len == 0:
        break
    (chunk_key,) = struct.unpack_from("<Q", chunk_keys, i * 8)
    chunk_code = chunk_codes[pos:pos + chunk_len]
    # The first three chunks use the original cipher; later chunks use the
    # variant selected by version=2.
    version = 1 if i < 3 else 2
    dec_chunk = rc4_decrypt(chunk_key, chunk_code, version=version)
    # Files are named after the chunk's offset in the code blob. Existing
    # files are left untouched.
    out_path = f"dec/chunk.{pos:04x}"
    if not os.path.exists(out_path):
        with open(out_path, "wb") as f:
            f.write(dec_chunk)
    pos += chunk_len

Reversing what each chunk does is not exactly as straightforward as reading its standalone disassembly, since chunks often reference parts of the program outside themselves and relative to where they are mapped. One trick to dealing with this is to map the chunks into the original binary again (one at a time) at the position they would be in memory relative to the loader code, and use a disassembler to make sense of what is being referenced.

As it turns out, the first few chunks are responsible for outputting the greeter text, getting user input, xoring it with 0x0d and modifying the encryption algorithm (hence the version parameter).

The rest of the chunks validate the input through multiple checksums and finally output whether the flag was correct.

Each chunk follows a similar format:

/* load start of user input */
   0:   48 8d 5c 24 08          lea    rbx, [rsp+0x8]

/* compute some val in rdi, by multiplying dwords from
   compute blob (see below) with dwords from user input */
   5:   48 8d 15 2f 00 00 00    lea    rdx, [rip+0x2f]  # 0x3b
   c:   48 31 ff                xor    rdi, rdi
   f:   48 31 c0                xor    rax, rax
  12:   48 31 c9                xor    rcx, rcx
  15:   8b 04 8a                mov    eax, DWORD PTR [rdx+rcx*4]
  18:   0f af 04 8b             imul   eax, DWORD PTR [rbx+rcx*4]
  1c:   01 c7                   add    edi, eax
  1e:   fe c1                   inc    cl
  20:   48 83 f9 10             cmp    rcx, 0x10
  24:   72 ef                   jb     0x15

/* val = (val == 0x9b1bf3e5) */
  26:   2b 3d 4f 00 00 00       sub    edi, DWORD PTR [rip+0x4f]  # 0x7b
  2c:   f7 df                   neg    edi
  2e:   19 ff                   sbb    edi, edi
  30:   ff c7                   inc    edi

/* increment value on stack if correct */
  32:   58                      pop    rax
  33:   01 f8                   add    eax, edi
  35:   50                      push   rax

/* next chunk */
  36:   e9 c4 00 00 00          jmp    0xff

/* compute blob */
  3b:   07                      (bad)
  3c:   80 32 0a                xor    BYTE PTR [rdx], 0xa
  3f:   c0 47 41 c9             rol    BYTE PTR [rdi+0x41], 0xc9
  43:   ca 55 64                retf   0x6455
  46:   b7 61                   mov    bh, 0x61
  48:   8f                      (bad)
  49:   d6                      (bad)
  4a:   e0 89                   loopne 0xffffffffffffffd5
  4c:   b8 95 44 27 49          mov    eax, 0x49274495
  51:   37                      (bad)
  52:   26 90                   es nop
  54:   6c                      ins    BYTE PTR es:[rdi], dx
  55:   28 0c c6                sub    BYTE PTR [rsi+rax*8], cl
  58:   e3 0c                   jrcxz  0x66
  5a:   22 e2                   and    ah, dl
  5c:   72 bd                   jb     0x1b
  5e:   24 13                   and    al, 0x13
  60:   b1 f2                   mov    cl, 0xf2
  62:   1d be 0e 0e 04          sbb    eax, 0x40e0ebe
  67:   34 67                   xor    al, 0x67
  69:   89 f3                   mov    ebx, esi
  6b:   1c 22                   sbb    al, 0x22
  6d:   98                      cwde
  6e:   21 a5 41 3e 36 bd       and    DWORD PTR [rbp-0x42c9c1bf], esp
  74:   18 37                   sbb    BYTE PTR [rdi], dh
  76:   af                      scas   eax, DWORD PTR es:[rdi]
  77:   d0 5b 66                rcr    BYTE PTR [rbx+0x66], 1
  7a:   9a                      (bad)
  7b:   e5 f3                   in     eax, 0xf3
  7d:   1b                      .byte 0x1b
  7e:   9b                      fwait
  7f:   90                      nop
  80:   d5                      (bad)
  81:   d1                      .byte 0xd1

Exploit

We can extract the constraint information (compute blob and checksum) from each chunk exported and solve for an input string that satisfies these constraints using z3.

solve.py
import os
import struct

from z3 import BitVec, Or, Solver, sat

s = Solver()

# Decrypted chunks that are skipped because they carry no checksum
# constraint (presumably the setup chunks and the final verdict chunk).
chunk_excl = {"chunk.0000", "chunk.007f", "chunk.00db",
              "chunk.01b9", "chunk.0a49"}
# The 0x40-byte input is modelled as sixteen 32-bit words.
flag = [BitVec(f"f{i}", 32) for i in range(0x40 // 4)]
count = 0
for chunk in sorted(os.listdir("dec")):
    if chunk in chunk_excl:
        continue
    count += 1
    with open(os.path.join("dec", chunk), "rb") as f:
        data = f.read()
    # Sixteen little-endian dword coefficients sit at 0x3b..0x7a, directly
    # followed by the expected 32-bit checksum at 0x7b.
    checkvals = struct.unpack_from("<16I", data, 0x3b)
    (checksum,) = struct.unpack_from("<I", data, 0x7b)
    s.add(sum(val * word for val, word in zip(checkvals, flag)) == checksum)
if count != 0x10:
    raise SystemExit(f"expected 16 checksum chunks in dec/, found {count}")

while s.check() == sat:
    m = s.model()
    # model_completion assigns a value even to words the model left free.
    flagvals = [m.eval(v, model_completion=True).as_long() for v in flag]
    flagbytes = b"".join(struct.pack("<I", v) for v in flagvals)
    # Undo the XOR with 0x0d that is applied to the input before checking.
    flagbytes = bytes(c ^ 0x0d for c in flagbytes)
    print(flagbytes)
    # Exclude exactly this assignment: the next model must differ in at
    # least one word (not in every word).
    s.add(Or([v != val for v, val in zip(flag, flagvals)]))