GreyCTF 2023 - OT


Challenge

Oblivious Transfer is amazing :D


main.py
import secrets
import hashlib
from Crypto.Util.number import isPrime, long_to_bytes

FLAG = b'grey{fake_flag}'

e = 0x10001

def checkN(N):
    if (N < 0):
        return "what?"
    if (N.bit_length() != 4096):
        return "N should be 4096 bits"
    if (isPrime(N) or isPrime(N + 23)):
        return "Hey no cheating"
    return None

def xor(a, b):
    return bytes([i ^ j for i,j in zip(a,b)])

def encrypt(key, msg):
    key = hashlib.shake_256(long_to_bytes(key)).digest(len(msg))
    return xor(key, msg)

print("This is my new Oblivious transfer protocol built on top of the crypto primitive (factorisation is hard)\n")
print("You should first generate a number h which you know the factorisation,\n")
print("If you wish to know the first part of the key, send me h")
print(f"If you wish to know the second part of the key, send me h - {23}\n")

N = int(input(("Now what's your number: ")))

check = checkN(N)
if check != None:
    print(check)
    exit(0)

k1, k2 = secrets.randbelow(N), secrets.randbelow(N)
k = k1 ^ k2

print("Now I send you these 2 numbers\n")
print(f"pow(k1, e, N) = {pow(k1, e, N)}")
print(f"pow(k2, e, N+23) = {pow(k2, e, N + 23)}\n")

print("Since you only know how to factorise one of them, you can only get one part of the data :D\n")
print("This protocol is secure so sending this should not have any problem")
print(f"flag = {encrypt(k, FLAG).hex()}")
print("Bye bye!")

Overview

We are given the source code of a Python service that asks the client for a number and then prints an encrypted flag. The protocol is meant to let the client recover at most one of the two random keys k1 and k2. The goal is to choose the input number such that both k1 and k2 can be recovered, so that the flag, which is encrypted under k1 ^ k2, can be decrypted.

The input number is used as a positive 4096-bit modulus N, which is also the bound for k1 and k2, both chosen by secrets.randbelow(N). The server returns pow(k1, 0x10001, N) and pow(k2, 0x10001, N+23), along with the flag XORed with a SHAKE256 keystream derived from k1 ^ k2. The server rejects N if either N or N+23 is prime.
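The flag encryption is a plain XOR stream cipher, so the same function also decrypts, and any error in k1 ^ k2 produces garbage. A minimal sketch of that property, assuming `long_to_bytes_standin` (a hypothetical helper) behaves like PyCryptodome's `long_to_bytes` for non-negative integers; the key value is arbitrary.

```python
import hashlib


def long_to_bytes_standin(k: int) -> bytes:
    # Hypothetical stand-in for Crypto.Util.number.long_to_bytes:
    # minimal-length big-endian encoding, with 0 encoded as b"\x00"
    return k.to_bytes(max(1, (k.bit_length() + 7) // 8), "big")


def encrypt(key: int, msg: bytes) -> bytes:
    # Keystream = SHAKE256(key bytes) truncated to len(msg), XORed with the message
    ks = hashlib.shake_256(long_to_bytes_standin(key)).digest(len(msg))
    return bytes(a ^ b for a, b in zip(ks, msg))


key = 0x1234_5678_9ABC
msg = b"grey{fake_flag}"
ct = encrypt(key, msg)
print(encrypt(key, ct))      # the same function decrypts
print(encrypt(key ^ 1, ct))  # one wrong key bit gives an unrelated keystream
```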

Analysis

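To recover k1 and k2 from the results of their modular exponentiation, we need the value of the Carmichael function $\lambda(n)$ of each modulus. As in RSA, it determines the exponent $d$ that reverses the modular exponentiation: $d$ is the modular inverse of $e$ with respect to $\lambda(n)$, i.e. $d \equiv e^{-1} \pmod{\lambda(n)}$, such that $(m^e)^d \equiv m \pmod{n}$ so long as $m$ is coprime to $n$ (if $n$ is squarefree, the congruence holds for every $m$). For a squarefree $n = q_1 q_2 \cdots q_r$, $\lambda(n) = \operatorname{lcm}(q_1 - 1, q_2 - 1, \ldots, q_r - 1)$, so computing it requires the prime factorization of $n$.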

As such, the goal is to choose a number N such that the prime factorizations of both N and N+23 are known.
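A small sketch of the inversion described above, assuming Python 3.9+ (for `math.lcm` and `pow(e, -1, m)`). The helper names `carmichael_squarefree` and `rsa_invert` are hypothetical, and the toy modulus 2*23*83 is chosen only because it has the same shape as the N used in the solution (83 is a Sophie Germain prime). It checks that every residue, including those sharing a factor with n, survives the round trip, which matters because k1 and k2 are random.

```python
from math import lcm

E = 0x10001


def carmichael_squarefree(primes):
    # For n = product of DISTINCT primes, lambda(n) = lcm of (q - 1) over its primes
    return lcm(*(q - 1 for q in primes))


def rsa_invert(c, primes):
    n = 1
    for q in primes:
        n *= q
    d = pow(E, -1, carmichael_squarefree(primes))  # d = e^-1 mod lambda(n)
    return pow(c, d, n)


# Toy squarefree modulus shaped like N = 2 * 23 * p2
primes = [2, 23, 83]
n = 2 * 23 * 83
recovered = [rsa_invert(pow(m, E, n), primes) for m in range(n)]
print(recovered == list(range(n)))
```

Because n is squarefree, even m = 46 (which shares the factors 2 and 23 with n) is recovered.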

Solution

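If we choose N = 23*n, then N+23 = 23*(n+1), and the problem reduces to finding consecutive integers n and n+1 whose prime factorizations are both known. Choosing p1 = n+1 to be a safe prime ensures that n = 2*p2, where p2 = (p1-1)/2 is the corresponding Sophie Germain prime. Since N and N+23 are both multiples of 23, neither is prime, so the server's check passes as long as N has exactly 4096 bits.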
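The construction can be reproduced at toy scale. This sketch searches upward for a Sophie Germain prime p2 such that N = 46*p2 has exactly the requested number of bits. `is_probable_prime` and `find_n` are hypothetical helpers; the Miller-Rabin test with the first 12 primes as bases is deterministic only for 64-bit inputs, so for the real 4096-bit case you would use `isPrime` or `openssl dhparam` as the writeup does.

```python
SMALL_PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]


def is_probable_prime(n: int) -> bool:
    """Miller-Rabin with the first 12 primes as bases (deterministic for n < 2**64)."""
    if n < 2:
        return False
    for q in SMALL_PRIMES:
        if n % q == 0:
            return n == q
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in SMALL_PRIMES:
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True


def find_n(bits: int):
    """Return (N, p2, p1): p1 = 2*p2 + 1 is a safe prime and N = 23*(p1 - 1) has `bits` bits."""
    p2 = (1 << (bits - 1)) // 46 + 1  # smallest p2 with 46*p2 > 2**(bits-1)
    while True:
        N = 46 * p2  # = 2 * 23 * p2 = 23 * (p1 - 1)
        if N.bit_length() > bits:
            raise ValueError("no suitable safe prime in range")
        p1 = 2 * p2 + 1
        # p2 must differ from 11 and 23 so that N and N + 23 stay squarefree
        if p2 not in (11, 23) and is_probable_prime(p2) and is_probable_prime(p1):
            return N, p2, p1
        p2 += 1


print(find_n(64))
```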

Thus the prime factorizations of N and N+23 are 2*23*p2 and 23*p1 respectively, giving λ(N) = lcm(1, 22, p2-1) and λ(N+23) = lcm(22, p1-1).
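An end-to-end toy run of the attack, as a sketch: `server` is a hypothetical re-implementation of only the challenge's arithmetic (no checks, no I/O), `attack` and `keystream_xor` are hypothetical helpers, and the toy safe prime p1 = 167 (with p2 = 83) stands in for the 4091-bit one. Because both moduli are squarefree and k1, k2 are below their moduli, the keys are recovered exactly on every run.

```python
import hashlib
import secrets
from math import lcm

E = 0x10001


def keystream_xor(key: int, msg: bytes) -> bytes:
    # Same shape as the challenge's encrypt(); int.to_bytes stands in for long_to_bytes
    kb = key.to_bytes(max(1, (key.bit_length() + 7) // 8), "big")
    ks = hashlib.shake_256(kb).digest(len(msg))
    return bytes(a ^ b for a, b in zip(ks, msg))


def server(N: int, flag: bytes):
    # Toy version of the server: both keys are drawn below N, as in main.py
    k1, k2 = secrets.randbelow(N), secrets.randbelow(N)
    return pow(k1, E, N), pow(k2, E, N + 23), keystream_xor(k1 ^ k2, flag)


def attack(N, factors_n, factors_n23, c1, c2, ct):
    # lambda of a squarefree modulus is the lcm of (q - 1) over its prime factors
    lam1 = lcm(*(q - 1 for q in factors_n))
    lam2 = lcm(*(q - 1 for q in factors_n23))
    k1 = pow(c1, pow(E, -1, lam1), N)
    k2 = pow(c2, pow(E, -1, lam2), N + 23)
    return keystream_xor(k1 ^ k2, ct)


p1, p2 = 167, 83           # toy safe prime and its Sophie Germain prime
N = 23 * (p1 - 1)          # = 2 * 23 * 83 = 3818, and N + 23 = 23 * 167
flag = b"grey{toy_flag}"
c1, c2, ct = server(N, flag)
print(attack(N, [2, 23, p2], [23, p1], c1, c2, ct))
```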

solve.py
from pwn import *
from math import lcm
from Crypto.Util.number import isPrime, long_to_bytes
import hashlib  # used by decrypt(); not guaranteed to come in via `from pwn import *`

def checkN(N):
    if (N < 0):
        return "what?"
    if (N.bit_length() != 4096):
        return "N should be 4096 bits"
    if (isPrime(N) or isPrime(N + 23)):
        return "Hey no cheating"
    return None

def xor(a, b):
    return bytes([i ^ j for i,j in zip(a,b)])

def decrypt(key, msg):
    key = hashlib.shake_256(long_to_bytes(key)).digest(len(msg))
    return xor(key, msg)

# Generated with: openssl dhparam 4091 (outputs a safe prime p, i.e. (p-1)/2 is also prime).
# N = 23*(p-1) only has 4096 bits for some 4091-bit p, so it might take multiple tries.
p = """
07:79:28:c9:e4:cf:f8:bb:ee:5e:88:73:32:40:88:
08:2f:85:d8:e5:94:d0:c9:9b:ea:29:73:3b:6a:a7:
1d:ff:f7:c6:a2:96:cf:b9:a4:2f:a8:b1:26:ab:ec:
15:ae:14:06:89:ee:67:1e:55:de:89:db:b6:61:41:
5d:ff:0c:09:3f:04:09:85:98:e5:5c:3e:95:a2:5b:
a5:e9:0e:8b:4e:0a:7b:ea:eb:4b:d3:5f:10:fd:2e:
a6:8b:2a:91:21:a2:b5:c0:95:74:67:69:9c:76:95:
5d:8c:a5:db:a7:1a:b1:9b:84:35:95:98:fd:77:41:
05:68:da:84:0b:13:d4:bb:69:05:6e:92:57:f1:6a:
da:ed:7b:57:ae:dc:d0:96:9e:78:e5:35:cf:7c:1e:
3b:26:db:10:0d:aa:85:56:ae:27:ca:ea:f6:ef:b7:
d7:9f:58:02:b3:83:73:18:7d:f2:af:7b:a2:ba:63:
de:79:d5:cf:f1:66:37:2c:cf:8b:39:dc:af:1b:9d:
c2:b5:60:78:f8:a6:45:e9:7d:4c:34:a3:ec:3a:4e:
64:13:a9:de:f0:94:c9:d9:c8:d1:ee:e3:41:09:f1:
94:1c:e0:8e:00:9d:d3:80:51:84:37:2d:08:07:e6:
cb:83:d4:43:de:1c:84:ae:81:1a:04:9c:4b:3b:27:
7e:2a:aa:ac:e4:62:4f:ce:ee:5b:38:d0:cc:48:d3:
5e:17:fe:b7:83:98:4b:9f:8e:55:aa:c6:98:c3:66:
e9:10:eb:e9:28:9d:b8:a2:90:64:4a:24:bc:ea:d7:
0c:19:7f:6b:ae:5f:ea:03:25:0d:1e:ac:e1:7f:98:
18:2f:19:99:81:ae:79:29:67:5d:08:22:f7:59:54:
d0:07:07:30:3b:52:6a:3b:de:11:75:a7:f4:28:fc:
20:f4:be:f3:a0:6e:b4:2a:d6:20:26:21:61:54:02:
61:8c:c6:6b:73:43:46:23:57:5a:39:67:c2:95:2b:
dd:1b:4a:f1:94:50:c4:77:40:a7:30:31:ee:6e:3e:
6b:5a:51:fe:0a:e7:d1:80:98:c6:12:28:00:8a:94:
43:bb:9e:38:bb:24:6f:3c:2f:04:e3:71:4b:48:85:
b8:08:4d:f3:12:73:47:a1:38:2d:85:22:d3:dd:81:
1a:ca:eb:42:df:cf:32:5a:e7:23:a1:98:04:08:a7:
ae:87:7b:4e:ff:c2:35:a1:04:23:c1:9d:ae:46:1e:
43:9f:e8:5d:a7:96:b0:72:09:db:a2:ad:3d:c8:fe:
e4:76:71:15:60:d9:ea:0e:d8:d0:c5:b7:0a:76:3b:
29:c3:94:34:8d:bd:75:cd:9a:d5:ac:f2:38:13:02:
35:bf
"""
p = int.from_bytes(bytes([int(v,16) for v in p.replace("\n", "").split(":")]), "big")
import math
print(math.log2(p))
assert isPrime(p) and isPrime(p // 2)  # p is a safe prime
pp = p // 2  # Sophie Germain prime (p-1)/2; log2: 4090.? - 1 = 4089.?

e = 0x10001
n = pp * 2 * 23  # N = 23*(p-1); log2: 4089.? + 1 + 4.523 = 4094.? or 4095.?, i.e. 4095 or 4096 bits
n2 = p * 23      # N + 23 = 23*p
print(math.log2(n))
assert checkN(n) is None

def carmichael(*primes):
    # lambda of a product of distinct primes is the lcm of (q - 1)
    return lcm(*[q - 1 for q in primes])

d1 = pow(e, -1, carmichael(pp, 2, 23))  # inverse of e mod lambda(N)
d2 = pow(e, -1, carmichael(p, 23))      # inverse of e mod lambda(N + 23)

while True:
    io = process("python3 ./chall/main.py".split())
    io.sendline(str(n).encode())
    io.readuntil(b" = ")
    r1 = int(io.readline())  # pow(k1, e, N)
    io.readuntil(b" = ")
    r2 = int(io.readline())  # pow(k2, e, N + 23)
    io.readuntil(b" = ")
    flag = bytes.fromhex(io.readline().strip().decode())
    io.close()

    k1 = pow(r1, d1, n)
    k2 = pow(r2, d2, n2)

    out = decrypt(k1 ^ k2, flag)
    if b"grey" in out:
        print(out)
        break
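The "might take multiple tries" comment in solve.py can be quantified. For a 4091-bit p, N = 23*(p-1) always stays below 2**4096, so only the lower bound 23*(p-1) >= 2**4095 matters. The sketch below computes which fraction of the 4091-bit range satisfies it; reading that fraction as a success rate per `openssl dhparam 4091` run assumes safe primes are roughly uniformly spread over the range, which is a heuristic rather than something the writeup measured.

```python
TARGET_BITS = 4096  # required bit length of N (from checkN)
P_BITS = 4091       # bit length of the safe prime from `openssl dhparam 4091`

lo, hi = 1 << (P_BITS - 1), 1 << P_BITS  # every P_BITS-bit p lies in [lo, hi)

# Smallest p with 23 * (p - 1) >= 2**(TARGET_BITS - 1): ceil(2**4095 / 23) + 1
p_min = -(-(1 << (TARGET_BITS - 1)) // 23) + 1
assert lo < p_min < hi

# The upper bound 23 * (p - 1) < 2**TARGET_BITS holds for every P_BITS-bit p
assert 23 * (hi - 2) < (1 << TARGET_BITS)

fraction = (hi - p_min) / (hi - lo)
print(round(fraction, 3))  # → 0.609, i.e. about 14/23 of the range
```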