# RSA Prime Generation

In [None]:
def _try_composite(a, d, n, s):
    if pow(a, d, n) == 1:
        return False
    for i in range(s):
        if pow(a, 2**i * d, n) == n-1:
            return False
    return True # n  is definitely composite

def is_prime(n, _precision_for_huge_n=16):
    """Return True if the integer `n` is prime.

    Small factors are ruled out by trial division against `_known_primes`, then
    Miller-Rabin is run. Below 341550071728321 the base sets used are exact
    according to http://primes.utm.edu/prove/prove2_3.html. Above that bound
    the first `_precision_for_huge_n` entries of `_known_primes` are used as
    bases, so a True result is only "probably prime".

    Args:
        n: integer to test. Anything below 2 (0, 1, negatives) is not prime.
        _precision_for_huge_n: number of Miller-Rabin bases used when
            n >= 341550071728321.
    """
    # Handle 0, 1 and negatives up front instead of relying on pow() with a
    # non-positive modulus further down.
    if n < 2:
        return False
    if n in _known_primes:
        return True
    if any(n % p == 0 for p in _known_primes):
        return False
    # Strong pseudoprime to bases 2, 3, 5, 7. It also has the factor 151, so
    # trial division by the full table already rejects it; kept as a safeguard.
    if n == 3215031751:
        return False
    # Write n - 1 = d * 2**s with d odd.
    d, s = n - 1, 0
    while not d % 2:
        d, s = d >> 1, s + 1
    # (exclusive upper bound, bases): Miller-Rabin with these bases is exact below the bound.
    for bound, bases in (
        (1373653, (2, 3)),
        (25326001, (2, 3, 5)),
        (118670087467, (2, 3, 5, 7)),
        (2152302898747, (2, 3, 5, 7, 11)),
        (3474749660383, (2, 3, 5, 7, 11, 13)),
        (341550071728321, (2, 3, 5, 7, 11, 13, 17)),
    ):
        if n < bound:
            break
    else:
        bases = _known_primes[:_precision_for_huge_n]
    return not any(_try_composite(a, d, n, s) for a in bases)

# Bootstrap the small-prime table. is_prime() reads the global _known_primes,
# so it must already hold [2, 3] while the list below is being built. For
# candidates under 1000 that suffices: trial division by 2 and 3 plus the
# base-{2, 3} Miller-Rabin branch of is_prime().
_known_primes = [2, 3]
_known_primes.extend([candidate for candidate in range(5, 1000, 2) if is_prime(candidate)])

In [None]:
from secrets import token_bytes as rand_bytes

def generate_prime(length = 32):
    """Return a random prime with exactly 8 * `length` bits.

    Draws `length` bytes from `secrets.token_bytes` (imported as rand_bytes),
    then forces the top bit, so the result really has the requested size, and
    the low bit, so it is odd. It then steps by 2 until is_prime(candidate, 64)
    accepts. If the search runs past 2**bits - 1, a fresh candidate is drawn.

    Raises:
        ValueError: if `length` is less than 1.
    """
    if length < 1:
        raise ValueError("length must be at least 1 byte")
    bits = 8 * length
    while True:
        candidate = int.from_bytes(rand_bytes(length), 'big') | (1 << (bits - 1)) | 1
        # Stay within `bits` bits; leaving that range means we redraw.
        while candidate.bit_length() == bits:
            if is_prime(candidate, 64):
                return candidate
            candidate += 2

In [None]:
# Sanity check: a prime generated from 100 random bytes should pass is_prime.
# Left as the cell's last expression so Jupyter displays it (no print needed).
is_prime(generate_prime(100))

# RSA Algorithm

### Pseudocode
- Generate p and q, two random primes
- Calculate n = p * q
- Totient = (p - 1) * (q - 1) = n - p - q + 1
- Set e = 3 (if gcd(e, totient) ≠ 1, e has no inverse mod the totient, so generate new p and q)
- Calculate d = invmod(e, totient)
- Public Key: (e, n); Private Key: (d, n)


## Code

In [None]:
from math import gcd

class rsa:
    """Textbook (unpadded) RSA with e = 3, used to demonstrate attacks.

    It also acts as a decryption oracle that will not reveal `secret_message`:
    `decrypt` returns 0 whenever the recovered integer equals it.
    """

    def __init__(self):
        self.e = 3
        # Seed totient with a multiple of e so the loop body runs at least once.
        self.totient = 3
        # e must be invertible mod the totient; otherwise draw new primes.
        while gcd(self.e, self.totient) != 1:
            self.generate_primes()
        self.d = pow(self.e, -1, self.totient)
        assert self.d * self.e % self.totient == 1

        self.secret_message = b"this should not be decrypted"
        self.secret_message_integer = int.from_bytes(self.secret_message, 'big')

    def generate_primes(self):
        """Draw two distinct 32-byte primes and set n, the totient and e = 3."""
        p = generate_prime(32)
        q = generate_prime(32)
        # p == q would make n a perfect square (trivially factored) and
        # (p - 1)**2 the wrong totient, so redraw q in that case.
        while q == p:
            q = generate_prime(32)
        self.n = p * q
        self.totient = (p - 1) * (q - 1)
        self.e = 3

    def encrypt(self, message: bytes, endianess = 'big'):
        """Return message**e mod n, reading `message` as an unsigned integer.

        Raises:
            ValueError: if the message integer is >= n. It would be reduced
                mod n and could never be decrypted back to the original.
        """
        message_integer = int.from_bytes(message, endianess)
        if message_integer >= self.n:
            raise ValueError("message is too long for this modulus")
        return pow(message_integer, self.e, self.n)

    def decrypt(self, cipher: int, endianess = 'big'):
        """Return cipher**d mod n as an int, or 0 if it equals the secret message.

        `endianess` is accepted for symmetry with `encrypt` but is unused: the
        result is returned as an integer, not bytes.
        """
        message = pow(cipher, self.d, self.n)
        return 0 if message == self.secret_message_integer else message

    def public_key(self):
        """Return the public key as the tuple (e, n)."""
        return (self.e, self.n)

In [None]:
# Round trip: encrypting and then decrypting a short message should give it back.
rsa_client = rsa()
message = b"hello world"
round_trip = rsa_client.decrypt(rsa_client.encrypt(message))
round_trip.to_bytes(len(message), 'big')

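# Small-m Attack

With e = 3 and no padding, a short message satisfies $m^3 < n$, so $c = m^3$ exactly (the modular reduction never happens) and $m$ is simply the integer cube root of $c$.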

In [None]:
!pip install gmpy2

In [None]:
import gmpy2

# Small-m attack: with e = 3 and no padding, m**e < n means the modular
# reduction never happens, so the plaintext is the exact integer e-th root of c.
# rsa_client comes from the round-trip cell above; c and e are defined here so
# the cell also works on a fresh kernel.
e, n = rsa_client.public_key()
c = rsa_client.encrypt(b"hello world")

m, is_true_root = gmpy2.iroot(c, e)
if is_true_root:
    m = int(m)
    # Size the byte string from m itself; hex round-tripping breaks on odd-length hex.
    print("Message: {}".format(m.to_bytes((m.bit_length() + 7) // 8, 'big').decode().strip()))
else:
    print("No exact e-th root: m**e wrapped around n, so this attack does not apply.")

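# Unpadded Message Recovery Attack

The oracle refuses to decrypt the secret's ciphertext $c$, but it will decrypt $c' = S^e c \bmod n$ for any $S$. That returns $S m \bmod n$, and multiplying by $S^{-1} \bmod n$ recovers $m$.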



In [None]:
# Fresh key pair for this attack. From here on the attacker works only with
# the ciphertext of the oracle's secret and the public key.
rsa_client = rsa()
secret_message = rsa_client.secret_message
ciphertext = rsa_client.encrypt(secret_message)

In [None]:
# The oracle refuses to decrypt the secret directly, so this shows 0.
rsa_client.decrypt(cipher=ciphertext)

In [None]:
# Blind the ciphertext with S**e. The oracle sees (S*m)**e, which is not the
# forbidden secret, and decrypts it to S*m; then divide S back out mod n.
S = 10
exp, mod = rsa_client.public_key()
blinding_factor = pow(S, exp, mod)
new_ciphertext = blinding_factor * ciphertext % mod

new_plaintext = rsa_client.decrypt(new_ciphertext)
S_inverse = pow(S, -1, mod)
plaintext = new_plaintext * S_inverse % mod

plaintext

In [None]:
# Size the byte string from the integer itself. A fixed width overflows for long
# messages, and strip(b'\x00') would also remove genuine trailing NUL bytes.
plaintext.to_bytes((plaintext.bit_length() + 7) // 8, 'big')

# Flipping Bits [Square CTF 2018 C2]

You have two captured ciphertexts. The public key is ``(e1, n)``. But, due to a transient bit flip, the second ciphertext was encrypted with a faulty public key: ``(e2, n)``. Recover the plaintexts.

(The algorithm is RSA.)

```
ct1:  13981765388145083997703333682243956434148306954774120760845671024723583618341148528952063316653588928138430524040717841543528568326674293677228449651281422762216853098529425814740156575513620513245005576508982103360592761380293006244528169193632346512170599896471850340765607466109228426538780591853882736654
ct2:  79459949016924442856959059325390894723232586275925931898929445938338123216278271333902062872565058205136627757713051954083968874644581902371182266588247653857616029881453100387797111559677392017415298580136496204898016797180386402171968931958365160589774450964944023720256848731202333789801071962338635072065
e1:  13
e2:  15
modulus:  103109065902334620226101162008793963504256027939117020091876799039690801944735604259018655534860183205031069083254290258577291605287053538752280231959857465853228851714786887294961873006234153079187216285516823832102424110934062954272346111907571393964363630079343598511602013316604641904852018969178919051627
```

## Code

In [None]:
def bezout(a, b, x = 0, prev_x = 1, y = 1, prev_y = 0):
    """Compute gcd(a, b) and Bézout coefficients with the extended Euclidean algorithm.

    The arguments are first reordered so that a >= b. For positive inputs the
    returned (g, x, y) are non-negative and satisfy y*b - x*a == +g or -g
    (using a, b after reordering); the sign depends on the number of division
    steps. Example: bezout(13, 15) == (1, 6, 7), and 7*13 - 6*15 == 1.

    x, prev_x, y and prev_y are the algorithm's running coefficients and
    should normally be left at their defaults.

    Runs iteratively, so long Euclid chains (e.g. consecutive Fibonacci
    numbers) cannot hit Python's recursion limit.

    Raises:
        ZeroDivisionError: if b == 0 after reordering.
    """
    # 'a' has to be greater than or equal to 'b'
    if b > a:
        a, b = b, a

    while True:
        quotient, remainder = divmod(a, b)
        # remainder 0: b is the gcd
        if remainder == 0:
            return b, x, y
        # advance the coefficients: c_new = quotient * c + c_prev
        prev_x, x = x, quotient * x + prev_x
        prev_y, y = y, quotient * y + prev_y
        a, b = b, remainder

In [None]:
# Captured challenge data: two ciphertexts under the same modulus. The attack
# below assumes both encrypt the same plaintext. ct2 was produced with the
# bit-flipped (faulty) exponent e2 instead of e1.
ct1 = 13981765388145083997703333682243956434148306954774120760845671024723583618341148528952063316653588928138430524040717841543528568326674293677228449651281422762216853098529425814740156575513620513245005576508982103360592761380293006244528169193632346512170599896471850340765607466109228426538780591853882736654
ct2 = 79459949016924442856959059325390894723232586275925931898929445938338123216278271333902062872565058205136627757713051954083968874644581902371182266588247653857616029881453100387797111559677392017415298580136496204898016797180386402171968931958365160589774450964944023720256848731202333789801071962338635072065

# Public exponents: e1 is the intended one, e2 the faulty one.
e1 = 13
e2 = 15

# Shared RSA modulus n used for both ciphertexts.
mod = 103109065902334620226101162008793963504256027939117020091876799039690801944735604259018655534860183205031069083254290258577291605287053538752280231959857465853228851714786887294961873006234153079187216285516823832102424110934062954272346111907571393964363630079343598511602013316604641904852018969178919051627

In [None]:
# Named bezout_gcd rather than gcd so it does not shadow math.gcd. The rsa
# class looks gcd up at call time, so re-running rsa() afterwards would crash.
bezout_gcd, x, y = bezout(e1, e2)
# The common-modulus attack only works when e1 and e2 are coprime.
assert bezout_gcd == 1

In [None]:
# Check the Bezout coefficients; the recombination below expects this to be 1.
bezout_check = y * e1 - x * e2
bezout_check

In [None]:
# Common-modulus attack: ct1 = m**e1 and ct2 = m**e2 (mod n), so
# ct1**y * ct2**(-x) = m**(y*e1 - x*e2) (mod n). bezout() only guarantees that
# this exponent is +1 or -1 (the sign depends on the input pair), so undo -1.
exponent = y * e1 - x * e2
assert abs(exponent) == 1, "need y*e1 - x*e2 == +/-1 (gcd(e1, e2) must be 1)"
plaintext = pow(ct1, y, mod) * pow(ct2, -x, mod) % mod
if exponent == -1:
    plaintext = pow(plaintext, -1, mod)

In [None]:
# Size the byte string from the integer itself instead of a guessed width.
# strip(b'\x00') would also remove genuine trailing NUL bytes.
plaintext.to_bytes((plaintext.bit_length() + 7) // 8, 'big')

# Problems
- [ ] [miniRSA](https://play.picoctf.org/practice/challenge/73)
- [ ] [Mini RSA](https://play.picoctf.org/practice/challenge/188)
- [ ] [college rowing team](https://play.picoctf.org/practice/challenge/212)
- [ ] [sum-o-primes](https://play.picoctf.org/practice/challenge/310)
- CryptoHack, a site with RSA challenges at several difficulty levels: https://cryptohack.org/challenges/rsa/