added gnutls backend and moved backend code into new files
[ircu2.10.12-pk.git] / ircd / ssl.gnutls.c
1 /*
2  * IRC - Internet Relay Chat, ircd/ssl.gnutls.c
3  * Copyright (C) 2015 pk910 (Philipp Kreil)
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation; either version 1, or (at your option)
8  * any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18  */
19 /** @file
20  * @brief Implementation of functions for handling ssl connections
21  * @version $Id$
22  */
23 #include <gnutls/gnutls.h>
24
25 #ifndef GNUTLS_SEC_PARAM_LEGACY
26 #define GNUTLS_SEC_PARAM_LEGACY 2
27 #endif
28  
29 struct SSLPendingConections {
30   struct SSLConnection *connection;
31   struct SSLPendingConections *next;
32   
33   void *data;
34   enum SSLDataType datatype;
35 };
36
37 struct SSLPendingConections *firstPendingConection = NULL;
38 int ssl_is_initialized = 0;
39 static gnutls_dh_params_t ssl_dh_params;
40 static unsigned int ssl_dh_params_bits;
41
42 static int ssl_generate_dh_params() {
43   ssl_dh_params_bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH, GNUTLS_SEC_PARAM_LEGACY);
44   gnutls_dh_params_init(&ssl_dh_params);
45   gnutls_dh_params_generate2(ssl_dh_params, ssl_dh_params_bits);
46   return 0;
47 }
48
49 static void ssl_init() {
50   if(ssl_is_initialized)
51     return;
52   ssl_is_initialized = 1;
53   int res;
54   res = gnutls_global_init();
55   
56   if(res != GNUTLS_E_SUCCESS) {
57     //TODO: Log Errors
58   }
59 }
60
61 void ssl_free_connection(struct SSLConnection *connection) {
62   gnutls_bye(connection->session, GNUTLS_SHUT_RDWR);
63   if(connection->credentials)
64     gnutls_certificate_free_credentials(connection->credentials);
65   gnutls_deinit(connection->session);
66   free(connection);
67 }
68
69 void ssl_free_listener(struct SSLListener *listener) {
70   gnutls_certificate_free_credentials(listener->credentials);
71   gnutls_priority_deinit(listener->priority);
72   free(listener);
73 }
74
75 static void ssl_handshake_completed(struct SSLConnection *connection, int success) {
76   struct SSLPendingConections *pending, *lastPending = NULL;
77   for(pending = firstPendingConection; pending; pending = pending->next) {
78     if(pending->connection == connection) {
79       if(lastPending)
80         lastPending->next = pending->next;
81       else
82         firstPendingConection = pending->next;
83       switch(pending->datatype) {
84         case SSLData_Client: {
85             struct Client *cptr = (struct Client *) pending->data;
86             if(success) {
87               if(FlagHas(&connection->flags, SSLFLAG_INCOMING))
88                 start_auth(cptr);
89               else if(!completed_connection(cptr))
90                 exit_client_msg(cptr, cptr, &me, "Registration failed.");
91             } else
92               exit_client_msg(cptr, cptr, &me, "SSL Handshake failed.");
93           }
94           break;
95       }
96       free(pending);
97     }
98     lastPending = pending;
99   }
100 }
101
102 static int ssl_handshake_outgoing(struct SSLConnection *connection) {
103   int ret = gnutls_handshake(connection->session);
104   FlagClr(&connection->flags, SSLFLAG_HANDSHAKE_R);
105   FlagClr(&connection->flags, SSLFLAG_HANDSHAKE_W);
106   
107   if(ret < 0) {
108     if(gnutls_error_is_fatal(ret) == 0) {
109       if(gnutls_record_get_direction(connection->session))
110         FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_W);
111       else
112         FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_R);
113       return 1;
114     } else {
115       
116       return 0;
117     }
118   } else {
119     FlagClr(&connection->flags, SSLFLAG_HANDSHAKE);
120     FlagSet(&connection->flags, SSLFLAG_READY);
121     ssl_handshake_completed(connection, 1);
122     return 0;
123   }
124 }
125
126 struct SSLConnection *ssl_create_connect(int fd, void *data, enum SSLDataType datatype) {
127   struct SSLConnection *connection = malloc(sizeof(*connection));
128   struct SSLPendingConections *pending = NULL;
129   
130   if(!connection)
131     return NULL;
132   
133   if(!ssl_is_initialized)
134     ssl_init();
135   
136   gnutls_certificate_allocate_credentials(&connection->credentials);
137   gnutls_init(&connection->session, GNUTLS_CLIENT);
138   
139   gnutls_priority_set_direct(connection->session, "SECURE128:+SECURE192:-VERS-TLS-ALL:+VERS-TLS1.2", NULL);
140   gnutls_credentials_set(connection->session, GNUTLS_CRD_CERTIFICATE, connection->credentials);
141   
142   gnutls_transport_set_ptr(connection->session, (gnutls_transport_ptr_t)fd);
143   //gnutls_handshake_set_timeout(connection->session, 30);
144   
145   FlagSet(&connection->flags, SSLFLAG_OUTGOING);
146   FlagSet(&connection->flags, SSLFLAG_HANDSHAKE);
147   
148   pending = malloc(sizeof(*pending));
149   if(!pending) {
150     goto ssl_create_connect_failed;
151   }
152   pending->connection = connection;
153   pending->next = firstPendingConection;
154   firstPendingConection = pending;
155   
156   pending->data = data;
157   pending->datatype = datatype;
158   
159   ssl_handshake_outgoing(connection);
160   
161   return connection;
162 ssl_create_connect_failed:
163   free(connection);
164   return NULL;
165 }
166
167 void ssl_start_handshake_connect(struct SSLConnection *connection) {
168   ssl_handshake_outgoing(connection);
169 }
170
171 struct SSLListener *ssl_create_listener() {
172   if(!ssl_is_initialized)
173     ssl_init();
174   
175   struct SSLListener *listener = calloc(1, sizeof(*listener));
176   
177   gnutls_priority_init(&listener->priority, "SECURE128:+SECURE192:-VERS-TLS-ALL:+VERS-TLS1.2", NULL);
178   gnutls_certificate_allocate_credentials(&listener->credentials);
179   
180   char *certfile = conf_get_local()->sslcertfile;
181   char *keyfile = conf_get_local()->sslkeyfile;
182   char *cafile = conf_get_local()->sslcafile;
183   
184   if(!certfile) {
185     goto ssl_create_listener_failed;
186   }
187   if(!keyfile) {
188     keyfile = certfile;
189   }
190   
191   /* load certificate */
192   if(gnutls_certificate_set_x509_key_file(listener->credentials, certfile, keyfile, GNUTLS_X509_FMT_PEM) < 0) {
193     goto ssl_create_listener_failed;
194   }
195   /* load cafile */
196   //TODO: ca file check!
197   /*
198   if(cafile && cafile[0] && SSL_CTX_load_verify_locations(listener->context, cafile, NULL) <= 0) {
199     goto ssl_create_listener_failed;
200   }
201   */
202   
203   gnutls_certificate_set_dh_params(listener->credentials, ssl_dh_params);
204   
205   FlagSet(&listener->flags, SSLFLAG_READY);
206   return listener;
207 ssl_create_listener_failed:
208   free(listener);
209   return NULL;
210 }
211
212 struct SSLConnection *ssl_start_handshake_listener(struct SSLListener *listener, int fd, void *data, enum SSLDataType datatype) {
213   if(!listener)
214     return NULL;
215   struct SSLPendingConections *pending = NULL;
216   struct SSLConnection *connection = malloc(sizeof(*connection));
217   
218   gnutls_init(&connection->session, GNUTLS_SERVER);
219   gnutls_priority_set(connection->session, listener->priority);
220   gnutls_credentials_set(connection->session, GNUTLS_CRD_CERTIFICATE, listener->credentials);
221   connection->credentials = NULL;
222   gnutls_dh_set_prime_bits(connection->session, ssl_dh_params_bits);
223   gnutls_certificate_server_set_request(connection->session, GNUTLS_CERT_IGNORE);
224   
225   gnutls_transport_set_ptr(connection->session, (gnutls_transport_ptr_t)fd);
226   
227   FlagSet(&connection->flags, SSLFLAG_INCOMING);
228   FlagSet(&connection->flags, SSLFLAG_HANDSHAKE);
229   FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_R);
230   
231   pending = malloc(sizeof(*pending));
232   if(!pending) {
233     goto ssl_start_handshake_listener_failed;
234   }
235   pending->connection = connection;
236   pending->next = firstPendingConection;
237   firstPendingConection = pending;
238   
239   pending->data = data;
240   pending->datatype = datatype;
241   
242   ssl_handshake_outgoing(connection);
243   return connection;
244 ssl_start_handshake_listener_failed:
245   free(connection);
246   return NULL;
247 }
248
249 IOResult ssl_recv_decrypt(struct SSLConnection *connection, char *buf, unsigned int buflen, unsigned int *len) {
250   if(FlagHas(&connection->flags, SSLFLAG_HANDSHAKE)) {
251     ssl_handshake_outgoing(connection);
252     return IO_BLOCKED;
253   }
254   
255   int ret = gnutls_record_recv(connection->session, buf, buflen);
256   
257   if(ret == 0) {
258     return IO_FAILURE;
259   } else if(ret < 0 && gnutls_error_is_fatal(ret) == 0) {
260     if(ret == GNUTLS_E_REHANDSHAKE) {
261       FlagSet(&connection->flags, SSLFLAG_HANDSHAKE);
262       ssl_handshake_outgoing(connection);
263       return IO_BLOCKED;
264     }
265   } else if(ret < 0) {
266     return IO_FAILURE;
267   }
268   *len = ret;
269   return IO_SUCCESS;
270 }
271
272 static ssize_t ssl_writev(gnutls_session_t ssl, const struct iovec *vector, int count) {
273   char *buffer;
274   register char *bp;
275   size_t bytes, to_copy;
276   int i;
277
278   /* Find the total number of bytes to be written.  */
279   bytes = 0;
280   for (i = 0; i < count; ++i)
281     bytes += vector[i].iov_len;
282
283   /* Allocate a temporary buffer to hold the data.  */
284   buffer = (char *) alloca (bytes);
285
286   /* Copy the data into BUFFER.  */
287   to_copy = bytes;
288   bp = buffer;
289   for (i = 0; i < count; ++i) {
290     size_t copy = ((vector[i].iov_len) > (to_copy) ? (to_copy) : (vector[i].iov_len));
291     memcpy ((void *) bp, (void *) vector[i].iov_base, copy);
292     bp += copy;
293     to_copy -= copy;
294     if (to_copy == 0)
295       break;
296   }
297   return gnutls_record_send(ssl, buffer, bytes);
298 }
299
300 IOResult ssl_send_encrypt_plain(struct SSLConnection *connection, char* buf, int len) {
301   return gnutls_record_send(connection->session, buf, len);
302 }
303
304 IOResult ssl_send_encrypt(struct SSLConnection *connection, struct MsgQ* buf, unsigned int *count_in, unsigned int *count_out) {
305   int res;
306   int count;
307   struct iovec iov[IOV_MAX];
308
309   assert(0 != buf);
310   assert(0 != count_in);
311   assert(0 != count_out);
312
313   *count_in = 0;
314   count = msgq_mapiov(buf, iov, IOV_MAX, count_in);
315   res = ssl_writev(connection->session, iov, count);
316
317   if(res == 0) {
318     *count_out = 0;
319     return IO_FAILURE;
320   } else if(res < 0 && gnutls_error_is_fatal(res) == 0) {
321     if(res == GNUTLS_E_REHANDSHAKE) {
322       FlagSet(&connection->flags, SSLFLAG_HANDSHAKE);
323       ssl_handshake_outgoing(connection);
324       return IO_BLOCKED;
325     }
326   } else if(res < 0) {
327     *count_out = 0;
328     return IO_FAILURE;
329   }
330   *count_out = (unsigned) res;
331   return IO_SUCCESS;
332 }
333
334 int ssl_connection_flush(struct SSLConnection *connection) {
335   if(connection) {
336     if(ssl_handshake(connection)) {
337       return ssl_handshake_outgoing(connection);
338     }
339   } else {
340     struct SSLPendingConections *curr, *last = NULL, *next;
341     for(curr = firstPendingConection; curr; curr = next) {
342       next = curr->next;
343       if(!ssl_connection_flush(curr->connection)) {
344         // connection is already in auth process here, curr is freed!
345         continue;
346       }
347       last = curr;
348     }
349   }
350   return 0;
351 }
352
353 const char* ssl_get_cipher(struct SSLConnection *connection) {
354   if(!connection)
355     return NULL;
356   static char buf[401];
357   const char *kx_name, *cipher_name, *mac_name;
358   unsigned int len, i;
359   char *dest;
360
361   kx_name = gnutls_kx_get_name(gnutls_kx_get(connection->session));
362   cipher_name = gnutls_cipher_get_name(gnutls_cipher_get(connection->session));
363   mac_name = gnutls_mac_get_name(gnutls_mac_get(connection->session));
364
365   if(!kx_name || !cipher_name || !mac_name) {
366     return "<invalid>";
367   }
368
369   len = strlen(kx_name) + strlen(cipher_name) + strlen(mac_name);
370   if(len > 395) {
371     return "<invalid>";
372   }
373   else {
374     dest = buf;
375     i = 0;
376     while((*dest++ = kx_name[i++])) /* empty */ ;
377     *(dest - 1) = '-';
378     i = 0;
379     while((*dest++ = cipher_name[i++])) /* empty */ ;
380     *(dest - 1) = '-';
381     i = 0;
382     while((*dest++ = mac_name[i++])) /* empty */ ;
383     return buf;
384   }
385 }