#include #include #include #include "moepay.h" #include #include #include enum { KiB = 1024, Ibuf = 4 * KiB, Obuf = 4 * KiB, Stack = 8 * KiB, Timeout = 5, Threshold = 128, }; enum { TLSFD, TXTFD, }; typedef struct Tls Tls; struct Tls { struct pollfd pfds[2]; struct { atomic_flag f; ARendez r; } go; mbedtls_ssl_context ssl; mbedtls_net_context net; byte ibuf[Ibuf]; byte obuf[Obuf]; }; static int _p( int e ) { char buf[ERRMAX]; if (e < 0) { mbedtls_strerror(e, buf, sizeof(buf)); werrstr("%s", buf); } return e; } static int closeintr( int fd ) { while (close(fd) < 0) { if (errno != EINTR) { return -1; } } return 0; } /* * We cannot trust that mbedtls actually handles async shit correctly, so we * need to spawn two threads to deal with both input and output. * * Conversely, we cannot trust mbedtls to synchronize between threads, * presumably because it was written by absolute retards. So we now actually * need to trust their half-baked async code and waste tonnes of syscalls * polling sockets and writing a horrible state machine. Thanks mbedtls. */ static int setup( Tls *ctx ) { sigset_t set; /* try to mask all signals since we skimp on the stack */ sigfillset(&set); pthread_sigmask(SIG_BLOCK, &set, nil); /* wait for creator to greenlight us */ if (!atomic_flag_test_and_set(&ctx->go.f)) { if (arendez(&ctx->go.r, nil) != nil) { return -1; } } return 0; } static void relay( void *arg ) { Tls *ctx = arg; size_t ioff = 0, ooff = 0; size_t in = 0, out = 0; int tlsmask = ~0, txtmask = ~0; int shut = 0; ssize_t r; if (setup(ctx) != 0) { return; } while (1) { /* check for shutdown */ if (!(tlsmask & POLLIN) && !(shut & SHUT_WR) && in == 0) { shutdown(ctx->pfds[TXTFD].fd, SHUT_WR); shut |= SHUT_WR; } if (!(tlsmask & POLLOUT) && !(shut & SHUT_RD) && out == 0) { shutdown(ctx->pfds[TXTFD].fd, SHUT_RD); shut |= SHUT_RD; } /* check for complete stop */ if ((!(tlsmask & (POLLIN | POLLOUT)) && in == 0) || (!(txtmask & (POLLIN | POLLOUT)) && out == 0) || !((tlsmask | txtmask) & (POLLIN | POLLOUT))) { break; } /* check if it pays to move the buffer */ if (in < Threshold && ioff > Threshold) { memmove(ctx->ibuf, ctx->ibuf + ioff, in); ioff = 0; } if (out < Threshold && ooff > Threshold) { memmove(ctx->obuf, ctx->obuf + ooff, out); ooff = 0; } /* setup pollfds */ ctx->pfds[0].revents = 0; ctx->pfds[1].revents = 0; ctx->pfds[0].events = 0; ctx->pfds[1].events = 0; if (in > 0) { ctx->pfds[TXTFD].events |= POLLOUT; } if (out > 0) { ctx->pfds[TLSFD].events |= POLLOUT; } if (ioff + in < Ibuf) { ctx->pfds[TLSFD].events |= POLLIN; } if (ooff + out < Obuf) { ctx->pfds[TXTFD].events |= POLLIN; } ctx->pfds[TLSFD].events &= tlsmask; ctx->pfds[TXTFD].events &= txtmask; /* try to poll */ while ((r = poll(ctx->pfds, 2, -1)) < 0) { if (errno != EINTR) { nsleep((uvlong)1000000 * Timeout); ctx->pfds[0].revents = ctx->pfds[0].events & (POLLIN | POLLOUT); ctx->pfds[1].revents = ctx->pfds[1].events & (POLLIN | POLLOUT); break; } } if (ctx->pfds[TLSFD].revents & POLLHUP) { tlsmask &= ~(POLLIN | POLLOUT); } if (ctx->pfds[TXTFD].revents & POLLHUP) { txtmask &= ~(POLLIN | POLLOUT); } /* drain data between sockets */ if (ctx->pfds[TLSFD].revents & POLLIN) { r = mbedtls_ssl_read(&ctx->ssl, ctx->ibuf + ioff, sizeof(ctx->ibuf) - ioff - in); if (r <= 0) { if (r != MBEDTLS_ERR_SSL_WANT_READ || r != MBEDTLS_ERR_SSL_WANT_WRITE) { tlsmask &= ~POLLIN; } } else { in += (size_t)r; } } if (ctx->pfds[TLSFD].revents & POLLOUT) { r = mbedtls_ssl_write(&ctx->ssl, ctx->obuf + ooff, out); if (r <= 0) { if (r != MBEDTLS_ERR_SSL_WANT_READ || r != MBEDTLS_ERR_SSL_WANT_WRITE) { tlsmask &= ~POLLOUT; out = 0; } } else { ooff += (size_t)r; out -= (size_t)r; } if (out == 0) { ooff = 0; } } if (ctx->pfds[TXTFD].revents & POLLIN) { r = recv(ctx->pfds[TXTFD].fd, ctx->obuf + ooff, sizeof(ctx->obuf) - ooff - out, MSG_DONTWAIT); if (r <= 0) { if (r == 0 || (errno != EAGAIN && errno != EWOULDBLOCK)) { txtmask &= ~POLLIN; } } else { out += (size_t)r; } } if (ctx->pfds[TXTFD].revents & POLLOUT) { r = send(ctx->pfds[TXTFD].fd, ctx->ibuf + ioff, in, MSG_DONTWAIT); if (r <= 0) { if (r == 0 || (errno != EAGAIN && errno != EWOULDBLOCK)) { txtmask &= ~POLLOUT; in = 0; } } else { ioff += (size_t)r; in -= (size_t)r; } if (in == 0) { ioff = 0; } } } closeintr(ctx->pfds[0].fd); closeintr(ctx->pfds[1].fd); mbedtls_ssl_free(&ctx->ssl); } int mbedify( int fd, const void *conf, const char *hostname ) { int fds[2] = { -1, -1 }; int tlsfd = -1; int e, flags; Tls *ctx; ctx = mallocz(sizeof(*ctx)); if (!ctx) { return -1; } /* * Alright, so we need to create a pair of fd0<->fd1, then dup the given * fd and set FD_CLOEXEC on it. Then we dup one end of our pair on top of * the original fd. That way we'll have replaced the original fd with a new * one that speaks plaintext and we can read/write from the other end and * relay it to the original tls fd. */ if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, fds) != 0) { goto errout; } if ((tlsfd = fcntl(fd, F_DUPFD_CLOEXEC, 0)) < 0) { goto errout; } if ((flags = fcntl(tlsfd, F_GETFL)) < 0) { goto errout; } if (fcntl(tlsfd, F_SETFL, flags | O_NONBLOCK) < 0) { goto errout; } if ((flags = fcntl(fd, F_GETFD)) < 0) { goto errout; } flags = (flags & FD_CLOEXEC) ? O_CLOEXEC : 0; ctx->pfds[TLSFD].fd = tlsfd; ctx->pfds[TXTFD].fd = fds[0]; mbedtls_ssl_init(&ctx->ssl); if (_p(mbedtls_ssl_setup(&ctx->ssl, conf)) != 0) { goto errout; } if (hostname) { if (_p(mbedtls_ssl_set_hostname(&ctx->ssl, hostname)) != 0) { goto errout; } } mbedtls_net_init(&ctx->net); ctx->net.fd = tlsfd; mbedtls_ssl_set_bio(&ctx->ssl, &ctx->net, mbedtls_net_send, mbedtls_net_recv, mbedtls_net_recv_timeout); if (threadcreate(relay, ctx, Stack) < 0) { goto errout; } /* final dup, replaces the original fd with our socket pair must go last */ if (dup3(fds[1], fd, flags) < 0) { arendez(&ctx->go.r, (void *)0xd1e); goto errout; } /* signal thread to start */ if (atomic_flag_test_and_set(&ctx->go.f)) { arendez(&ctx->go.r, nil); } closeintr(fds[1]); return 0; errout: e = errno; if (fds[0] >= 0) { closeintr(fds[0]); } if (fds[1] >= 0) { closeintr(fds[1]); } if (tlsfd >= 0) { closeintr(tlsfd); } mbedtls_ssl_free(&ctx->ssl); free(ctx); errno = e; return -1; }