SSL_read及SSL_write支持超时

原始的 socket 编程中,让 read/write 支持超时是很容易实现的,例如使用 select,或者使用 setsockopt 设置读写超时,并在 read/write 出错后根据 errno(如 EAGAIN/EWOULDBLOCK)判断是否为超时引起。

但是在 SSL 编程中,直接对底层 socket 调用 select 或依赖 errno 判断超时并不可靠:SSL 层内部可能已缓冲了数据(select 因此会误判),errno 也不一定反映 SSL 调用出错的真实原因。

使用 setsockopt 在底层的 socket 上设置读写超时后,SSL_read/SSL_write 超时出错时,SSL_get_error 会返回 SSL 错误码 SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE;但是被信号中断或者底层 SSL 需要重新握手,也会导致 SSL_read/SSL_write 返回同样的 SSL 错误码。

如果能够将这些信号屏蔽(忽略)掉,并启用 SSL 自动重新握手,就能够实现 SSL_read/SSL_write 的超时检测。

  • 屏蔽信号

    忽略应用产生的信号,如:

    signal(SIGPIPE, SIG_IGN);
    signal(SIGCHLD, SIG_IGN);
    
  • 在底层socket上设置超时
    struct timeval tv;
    tv.tv_sec  = 10;
    tv.tv_usec = 0;
    setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char*)&tv, sizeof(struct timeval));
    setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char*)&tv, sizeof(struct timeval));
    
  • 启用自动重新握手
    SSL_CTX_set_mode(ctx, SSL_MODE_AUTO_RETRY);
    
  • SSL_read/SSL_write 出错时判断是否为超时
    int readed = SSL_read(ssl, data, size);
    if (readed <= 0) {
        if (SSL_get_error(ssl, readed) == SSL_ERROR_WANT_READ) {
            // timeout
        } else {
            // error
        }
    }
    
    int writed = SSL_write(ssl, data, size);
    if (writed <= 0) {
        if (SSL_get_error(ssl, writed) == SSL_ERROR_WANT_WRITE) {
            // timeout
        } else {
            // error
        }
    }
    

ssl 客户端示例代码

这个示例包括建立连接、读、写,以及超时设置、服务器证书验证。

#include <arpa/inet.h>
#include <errno.h>
#include <limits.h>
#include <netinet/in.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <openssl/x509_vfy.h>
#include <openssl/x509v3.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <strings.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>

/* CA bundle file and CA directory passed to SSL_CTX_load_verify_locations()
   when ssl_client_connect() is asked to verify the server (verify_host != NULL).
   NOTE(review): these are Debian/Ubuntu-style paths — confirm for the target
   distribution. */
#define SSL_CLIENT_CAFILE "/etc/ssl/certs/ca-certificates.crt"
#define SSL_CLIENT_CAPATH "/etc/ssl/certs/"

/*
 * Connect to an IPv4 TLS server and complete the client handshake.
 *
 * @param ip          server IPv4 address, stored as-is into sin_addr.s_addr
 *                    (so presumably already in network byte order — confirm
 *                    against callers)
 * @param port        server TCP port in host byte order (converted with htons)
 * @param ssl         out: the connected SSL object on success, NULL on failure
 * @param ctx         out: the SSL_CTX backing *ssl on success, NULL on failure
 * @param timeout     send/receive timeout in seconds, applied with
 *                    SO_SNDTIMEO/SO_RCVTIMEO right after the TCP connect, so
 *                    the TLS handshake's socket I/O runs under it as well
 * @param verify_host when non-NULL, the server certificate is verified against
 *                    SSL_CLIENT_CAFILE/SSL_CLIENT_CAPATH and must match this
 *                    host name; when NULL no certificate verification is done
 * @return the connected socket fd on success (the caller then owns the fd,
 *         *ssl and *ctx), or -1 on failure, in which case everything acquired
 *         here has already been released.
 */
int ssl_client_connect(uint32_t ip, uint16_t port, SSL** ssl, SSL_CTX** ctx,
                       uint32_t timeout, const char* verify_host) {
  int sock = -1;
  SSL_CTX* new_ctx = NULL;
  SSL* new_ssl = NULL;
  struct sockaddr_in dest;
  struct timeval tv = {0};
  char ip_string[INET_ADDRSTRLEN] = {'\0'};
  int ret;

  if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    fprintf(stderr, "create remote socket fd failed: %s\n", strerror(errno));
    goto fail;
  }

  memset(&dest, 0, sizeof(dest));
  dest.sin_family = AF_INET;
  dest.sin_port = htons(port);
  dest.sin_addr.s_addr = ip;

  /* Printable form of the address, used only in log messages. */
  inet_ntop(AF_INET, &dest.sin_addr, ip_string, sizeof(ip_string));

  if (connect(sock, (struct sockaddr*)&dest, sizeof(dest)) != 0) {
    fprintf(stderr, "connect to %s:%d failed: %s\n", ip_string, port,
            strerror(errno));
    goto fail;
  }

  fprintf(stderr, "tcp connect to %s:%d success\n", ip_string, port);

  /* Socket-level send/recv timeouts: once they expire, SSL_read/SSL_write on
     this socket fail with SSL_ERROR_WANT_READ/SSL_ERROR_WANT_WRITE.
     Per socket(7), a zero timeout means the operation never times out. */
  tv.tv_sec = (time_t)timeout;
  tv.tv_usec = 0;
  if (setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) != 0) {
    fprintf(stderr, "set send timeout failed: %s\n", strerror(errno));
    goto fail;
  }
  if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) != 0) {
    fprintf(stderr, "set recv timeout failed: %s\n", strerror(errno));
    goto fail;
  }

  new_ctx = SSL_CTX_new(SSLv23_client_method());
  if (NULL == new_ctx) {
    fprintf(stderr, "new ssl ctx failed\n");
    ERR_print_errors_fp(stderr);
    goto fail;
  }

  /* Enable automatic retry so SSL_read/SSL_write do not return early because
     of a renegotiation; such an early return would be indistinguishable from
     a recv timeout. */
  if (!(SSL_CTX_set_mode(new_ctx, SSL_MODE_AUTO_RETRY) & SSL_MODE_AUTO_RETRY)) {
    fprintf(stderr, "set ssl auto retry mode failed\n");
    ERR_print_errors_fp(stderr);
    goto fail;
  }

  if (verify_host) {
    /* Verify the server certificate chain against the system CA store. */
    if (!SSL_CTX_load_verify_locations(new_ctx, SSL_CLIENT_CAFILE,
                                       SSL_CLIENT_CAPATH)) {
      fprintf(stderr,
              "failed to load certificate verify locations. CAfile(%s) "
              "CApath(%s)\n",
              SSL_CLIENT_CAFILE, SSL_CLIENT_CAPATH);
      ERR_print_errors_fp(stderr);
      goto fail;
    }

    /* Also require the certificate to name verify_host, see
       https://wiki.openssl.org/index.php/Hostname_validation.
       X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS is deliberately not set: enabling
       it produced "Hostname mismatch" errors. */
    X509_VERIFY_PARAM* param = SSL_CTX_get0_param(new_ctx);
    if (!X509_VERIFY_PARAM_set1_host(param, verify_host, 0)) {
      fprintf(stderr, "set verify host(%s) failed\n", verify_host);
      ERR_print_errors_fp(stderr);
      goto fail;
    }

    SSL_CTX_set_verify(new_ctx, SSL_VERIFY_PEER, NULL);
    /* NOTE(review): no SNI is sent (SSL_set_tlsext_host_name is not called);
       servers hosting several names may present a different certificate. */
  }

  new_ssl = SSL_new(new_ctx);
  if (NULL == new_ssl) {
    fprintf(stderr, "new ssl failed\n");
    ERR_print_errors_fp(stderr);
    goto fail;
  }

  if (1 != SSL_set_fd(new_ssl, sock)) {
    fprintf(stderr, "set ssl fd failed\n");
    ERR_print_errors_fp(stderr);
    goto fail;
  }

  /* SSL_get_error() is only reliable if the error queue was empty before the
     I/O call, so drop anything left over from earlier operations. */
  ERR_clear_error();
  ret = SSL_connect(new_ssl);
  if (ret <= 0) {
    /* Capture errno and the SSL error class before ERR_get_error() pops the
       queue that SSL_get_error() inspects. */
    int saved_errno = errno;
    int ssl_error = SSL_get_error(new_ssl, ret);
    char error[128] = {'\0'};
    ERR_error_string_n(ERR_get_error(), error, sizeof(error));
    fprintf(stderr, "ssl connect to %s:%d failed(%d) error(%s) errno(%d)\n",
            ip_string, port, ssl_error, error, saved_errno);
    if (verify_host) {
      long verify_result = SSL_get_verify_result(new_ssl);
      if (verify_result != X509_V_OK) {
        fprintf(stderr, "ssl certificate error(%s)\n",
                X509_verify_cert_error_string(verify_result));
      }
    }
    goto fail;
  }

  fprintf(stderr, "ssl connect to %s:%d success\n", ip_string, port);

  *ssl = new_ssl;
  *ctx = new_ctx;
  return sock;

fail:
  /* SSL_free/SSL_CTX_free accept NULL. SSL_free does not close the fd handed
     over with SSL_set_fd, so the socket is closed separately. */
  SSL_free(new_ssl);
  SSL_CTX_free(new_ctx);
  if (sock >= 0) {
    close(sock);
  }
  *ssl = NULL;
  *ctx = NULL;
  return -1;
}

/*
 * Read exactly nbytes bytes from the TLS connection into data.
 *
 * @param ssl     connected SSL object
 * @param data    destination buffer, at least nbytes bytes long
 * @param nbytes  number of bytes to read; 0 returns 0 immediately
 * @param timeout not used here; the read timeout comes from the socket's
 *                SO_RCVTIMEO (set by ssl_client_connect when ssl came from it)
 * @return 0 once all nbytes bytes have been read; -1 if SSL_read fails first
 *         (error, timeout, or the peer closing the connection). The logged
 *         SSL error is SSL_ERROR_WANT_READ on a receive timeout, presumably
 *         only when SO_RCVTIMEO and SSL_MODE_AUTO_RETRY are both in effect.
 */
int ssl_client_read(SSL* ssl, char* data, uint32_t nbytes, uint32_t timeout) {
  (void)timeout;
  uint32_t done = 0;

  while (done < nbytes) {
    /* SSL_read takes an int length, so never request more than INT_MAX. */
    uint32_t left = nbytes - done;
    int want = left > (uint32_t)INT_MAX ? INT_MAX : (int)left;

    /* SSL_get_error() needs an empty error queue before the I/O call. */
    ERR_clear_error();
    int got = SSL_read(ssl, data + done, want);
    if (got <= 0) {
      int saved_errno = errno;
      fprintf(stderr, "ssl read error(%d) readed(%d) errno(%d)\n",
              SSL_get_error(ssl, got), got, saved_errno);
      return -1;
    }
    done += (uint32_t)got;
  }

  return 0;
}

/*
 * Write exactly nbytes bytes from data to the TLS connection.
 *
 * @param ssl     connected SSL object
 * @param data    source buffer holding at least nbytes bytes
 * @param nbytes  number of bytes to write; 0 returns 0 immediately
 * @param timeout not used here; the write timeout comes from the socket's
 *                SO_SNDTIMEO (set by ssl_client_connect when ssl came from it)
 * @return 0 once all nbytes bytes have been written; -1 if SSL_write fails
 *         first. The logged SSL error is SSL_ERROR_WANT_WRITE on a send
 *         timeout, presumably only when SO_SNDTIMEO and SSL_MODE_AUTO_RETRY
 *         are both in effect.
 */
int ssl_client_write(SSL* ssl, char* data, uint32_t nbytes, uint32_t timeout) {
  (void)timeout;
  uint32_t done = 0;

  while (done < nbytes) {
    /* SSL_write takes an int length, so never submit more than INT_MAX. */
    uint32_t left = nbytes - done;
    int chunk = left > (uint32_t)INT_MAX ? INT_MAX : (int)left;

    /* SSL_get_error() needs an empty error queue before the I/O call. */
    ERR_clear_error();
    int sent = SSL_write(ssl, data + done, chunk);
    if (sent <= 0) {
      int saved_errno = errno;
      fprintf(stderr, "ssl write error(%d) writed(%d) errno(%d)\n",
              SSL_get_error(ssl, sent), sent, saved_errno);
      return -1;
    }
    done += (uint32_t)sent;
  }

  return 0;
}