import os
import sys

import paramiko


class ModifiedRSAClient:
    """
    Connect to an SSH host using a modified RSA public key and signature.

    During public key authentication, an SSH client sends its public key to
    the SSH host. If this public key is a certificate, the signature of the
    certificate is verified by OpenSSH. This class allows for modification of
    the public key and signature in the certificate parsed by OpenSSH.
    """

    def __init__(self, host, port, username,
                 user_privkey_file="./ssh_user",
                 user_cert_file="./ssh_user-cert.pub",
                 ca_privkey_file="./ssh_user_ca"):
        """
        Load the CA key, user key and user certificate, and locate the bytes
        that ``query`` will overwrite.

        :param host: SSH server host name.
        :param port: SSH server port.
        :param username: user name to authenticate as.
        :param user_privkey_file: path to the user's RSA private key.
        :param user_cert_file: path to the user's certificate.
        :param ca_privkey_file: path to the CA's RSA private key; its modulus
            and exponent are used to locate n and the signature.

        If any of the three files is missing, prints the ssh-keygen commands
        that create it and exits the process with status -1.
        """
        self.host = host
        self.port = port
        self.username = username

        if not os.path.isfile(ca_privkey_file):
            print("CA private key does not exist. Run")
            print(f"ssh-keygen -t rsa -b 4096 -f {ca_privkey_file}")
            print(f"ssh-keygen -t rsa -b 4096 -f {user_privkey_file}")
            print(f"ssh-keygen -s {ca_privkey_file} -I ca {user_privkey_file}")
            sys.exit(-1)

        if not os.path.isfile(user_privkey_file):
            print("User private key does not exist. Run")
            print(f"ssh-keygen -t rsa -b 4096 -f {user_privkey_file}")
            print(f"ssh-keygen -s {ca_privkey_file} -I ca {user_privkey_file}")
            sys.exit(-1)

        if not os.path.isfile(user_cert_file):
            print("User certificate does not exist. Run")
            print(f"ssh-keygen -s {ca_privkey_file} -I ca {user_privkey_file}")
            sys.exit(-1)

        self.ca_key = paramiko.RSAKey.from_private_key_file(ca_privkey_file)
        self.user_key = paramiko.RSAKey.from_private_key_file(user_privkey_file)
        self.user_key.load_certificate(user_cert_file)

        # CA public numbers: n is searched for in the certificate blob, and
        # (n, e) are used to recognise the CA signature inside it.
        self.n = self.ca_key.public_numbers.n
        self.e = self.ca_key.public_numbers.e
        # Modulus length in bytes (ceil(bits / 8)); also the signature length.
        self.modlen = (self.n.bit_length() + 7) // 8

        # Pristine copy of the certificate blob; every query patches a fresh
        # copy of this rather than a previously modified blob.
        self.key_blob = bytes(self.user_key.public_blob.key_blob)
        self.n_offset, self.sig_offset = self.get_offsets(
            self.key_blob, self.n, self.e, self.modlen
        )

    def get_offsets(self, blob, n, e, modlen):
        """
        Find the byte offsets of the RSA modulus and the RSA signature in a blob.

        :param blob: bytes containing an RSA public key and a signature made
            with the matching private key.
        :param n: RSA modulus to search for.
        :param e: RSA public exponent, used to verify candidate signatures.
        :param modlen: modulus (and signature) length in bytes.
        :returns: ``(n_offset, sig_offset)``, byte offsets into *blob*.
        :raises ValueError: if *n* does not occur in *blob*, or if no window of
            *blob* decrypts to PKCS#1 v1.5 signature padding.
        """
        n_bytes = n.to_bytes(modlen, byteorder="big")
        try:
            n_offset = blob.index(n_bytes)
        except ValueError:
            raise ValueError("Modulus n not found in blob.") from None

        padding_prefix = b"\x00\x01\xff\xff\xff\xff\xff\xff"
        # Scan candidate windows from the end of the blob backwards; the
        # signature is presumably the trailing field, so this usually
        # succeeds on the first iteration.
        for i in range(len(blob) - modlen, -1, -1):
            sig = int.from_bytes(blob[i:i + modlen], byteorder="big")
            msg_bytes = pow(sig, e, n).to_bytes(modlen, byteorder="big")
            # A genuine signature "decrypts" to PKCS#1 v1.5 type-1 padding.
            if msg_bytes[:8] == padding_prefix:
                sig_offset = i
                break
        else:
            # Raising a bare str is itself a TypeError in Python 3; raise a
            # real exception so callers see the intended message.
            raise ValueError("Signature offset not found.")
        return n_offset, sig_offset

    def query(self, n: bytes = None, sig: bytes = None) -> bool:
        """
        Attempt a login with a certificate whose modulus and/or signature
        bytes have been overwritten.

        :param n: replacement bytes written at ``self.n_offset``, exactly
            ``self.modlen`` long, or None to keep the original bytes.
        :param sig: replacement bytes written at ``self.sig_offset``, exactly
            ``self.modlen`` long, or None to keep the original bytes.
        :returns: True if ``SSHClient.connect`` succeeded, False if it raised
            ``AuthenticationException``. Other paramiko/socket errors propagate.
        :raises ValueError: if *n* or *sig* has the wrong length.
        """
        modlen = self.modlen
        # Explicit checks instead of assert, which is stripped under -O.
        for label, value in (("n", n), ("sig", sig)):
            if value is not None and len(value) != modlen:
                raise ValueError(
                    f"{label} must be exactly {modlen} bytes, got {len(value)}"
                )

        key_blob = bytearray(self.key_blob)
        if n is not None:
            key_blob[self.n_offset:self.n_offset + modlen] = n
        if sig is not None:
            key_blob[self.sig_offset:self.sig_offset + modlen] = sig

        original_blob = self.user_key.public_blob
        self.user_key.public_blob = paramiko.PublicBlob(
            original_blob.key_type,
            bytes(key_blob),
            original_blob.comment,
        )

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            # look_for_keys=False (like allow_agent=False) ensures only the
            # modified key is offered, not fallback keys from ~/.ssh.
            client.connect(
                self.host,
                self.port,
                username=self.username,
                pkey=self.user_key,
                allow_agent=False,
                look_for_keys=False,
            )
        except paramiko.ssh_exception.AuthenticationException:
            return False
        finally:
            # Always release the connection and undo the mutation of the
            # shared user key, whatever the outcome.
            client.close()
            self.user_key.public_blob = original_blob
        return True


if __name__ == "__main__":
    client = ModifiedRSAClient("localhost", 22, "user")
    # Size the replacement values from the loaded CA key instead of assuming
    # a 4096-bit (512-byte) modulus.
    new_n = b"\xaa" * client.modlen
    new_sig = b"\xbb" * client.modlen
    accepted = client.query(n=new_n, sig=new_sig)
    print(f"Authentication {'succeeded' if accepted else 'failed'}")