e407bd4dcc8cf2934e8521b997addc695b971aca
[ircu2.10.12-pk.git] / ircd / ssl.c
1 /*
2  * IRC - Internet Relay Chat, ircd/ssl.c
3  * Copyright (C) 2015 pk910 (Philipp Kreil)
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation; either version 1, or (at your option)
8  * any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program; if not, write to the Free Software
17  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
18  */
19 /** @file
20  * @brief Implementation of functions for handling local clients.
21  * @version $Id$
22  */
23 #include "config.h"
24
25 #include "client.h"
26 #include "ssl.h"
27 #include "class.h"
28 #include "ircd.h"
29 #include "ircd_features.h"
30 #include "ircd_log.h"
31 #include "ircd_reply.h"
32 #include "list.h"
33 #include "msgq.h"
34 #include "numeric.h"
35 #include "s_conf.h"
36 #include "s_debug.h"
37 #include "send.h"
38 #include "struct.h"
39
40 /* #include <assert.h> -- Now using assert in ircd_log.h */
41 #include <string.h>
42
43 #ifndef IOV_MAX
44 #define IOV_MAX 16      /**< minimum required length of an iovec array */
45 #endif
46
47 #if defined(HAVE_OPENSSL_SSL_H)
48
49 static struct SSLPendingConections {
50   struct SSLConnection *connection;
51   struct SSLPendingConections *next;
52   
53   void *data;
54   enum SSLDataType datatype;
55 };
56
57 struct SSLPendingConections *firstPendingConection = NULL;
58 int ssl_is_initialized = 0;
59
60 static void ssl_init() {
61   if(ssl_is_initialized)
62     return;
63   ssl_is_initialized = 1;
64         SSL_library_init();
65         OpenSSL_add_all_algorithms(); /* load & register all cryptos, etc. */
66         SSL_load_error_strings();
67 }
68
69 void ssl_free_connection(struct SSLConnection *connection) {
70   SSL_CTX *context = NULL;
71   if(FlagHas(&connection->flags, SSLFLAG_OUTGOING)) {
72     struct SSLOutConnection *outconn = (struct SSLOutConnection *)connection;
73     context = outconn->context;
74   }
75   SSL_shutdown(connection->session);
76   SSL_free(connection->session);
77   if(context)
78     SSL_CTX_free(context);
79   free(connection);
80 }
81
82 void ssl_free_listener(struct SSLListener *listener) {
83   SSL_CTX_free(listener->context);
84   free(listener);
85 }
86
87 static void ssl_handshake_completed(struct SSLConnection *connection, int success) {
88   struct SSLPendingConections *pending, *lastPending = NULL;
89   for(pending = firstPendingConection; pending; pending = pending->next) {
90     if(pending->connection == connection) {
91       if(lastPending)
92         lastPending->next = pending->next;
93       else
94         firstPendingConection = pending->next;
95       switch(pending->datatype) {
96         case SSLData_Client: {
97             struct Client *cptr = (struct Client *) pending->data;
98             if(success) {
99               if(FlagHas(&connection->flags, SSLFLAG_INCOMING))
100                 start_auth(cptr);
101               else if(!completed_connection(cptr))
102                 exit_client_msg(cptr, cptr, &me, "Registration failed.");
103             } else
104               exit_client_msg(cptr, cptr, &me, "SSL Handshake failed.");
105           }
106           break;
107       }
108       free(pending);
109     }
110     lastPending = pending;
111   }
112 }
113
114 static int ssl_handshake_outgoing(struct SSLConnection *connection) {
115   int ret = SSL_do_handshake(connection->session);
116   FlagClr(&connection->flags, SSLFLAG_HANDSHAKE_R);
117   FlagClr(&connection->flags, SSLFLAG_HANDSHAKE_W);
118   
119         switch(SSL_get_error(connection->session, ret)) {
120                 case SSL_ERROR_NONE:
121                         FlagClr(&connection->flags, SSLFLAG_HANDSHAKE);
122       FlagSet(&connection->flags, SSLFLAG_READY);
123       
124       ssl_handshake_completed(connection, 1);
125                         break;
126                 case SSL_ERROR_WANT_READ:
127                         FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_R);
128                         break;
129                 case SSL_ERROR_WANT_WRITE:
130                         FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_W);
131                         break;
132                 default:
133                         
134                         break;
135         }
136 }
137
138 struct SSLConnection *ssl_create_connect(int fd, void *data, enum SSLDataType datatype) {
139   struct SSLOutConnection *connection = malloc(sizeof(*connection));
140   struct SSLConnection *sslconn = (struct SSLConnection *)connection;
141   struct SSLPendingConections *pending = NULL;
142   
143   if(!connection)
144     return NULL;
145   
146   if(!ssl_is_initialized)
147     ssl_init();
148   
149   connection->context = SSL_CTX_new(SSLv23_client_method());
150         if(!connection->context) {
151                 goto ssl_create_connect_failed;
152         }
153         connection->session = SSL_new(connection->context);
154         if(!connection->session) {
155                 goto ssl_create_connect_failed;
156         }
157         if(!SSL_set_fd(connection->session, fd)) {
158                 goto ssl_create_connect_failed;
159         }
160         SSL_set_connect_state(connection->session);
161         FlagSet(&connection->flags, SSLFLAG_OUTGOING);
162   FlagSet(&connection->flags, SSLFLAG_HANDSHAKE);
163   
164   pending = malloc(sizeof(*pending));
165   if(!pending) {
166     goto ssl_create_connect_failed;
167   }
168   pending->connection = connection;
169   pending->next = firstPendingConection;
170   firstPendingConection = pending;
171   
172   pending->data = data;
173   pending->datatype = datatype;
174   
175         return sslconn;
176 ssl_create_connect_failed:
177         free(connection);
178         return NULL;
179 }
180
181 void ssl_start_handshake_connect(struct SSLConnection *connection) {
182   ssl_handshake_outgoing(connection);
183 }
184
185 struct SSLListener *ssl_create_listener() {
186   if(!ssl_is_initialized)
187     ssl_init();
188   
189   struct SSLListener *listener = calloc(1, sizeof(*listener));
190   listener->context = SSL_CTX_new(SSLv23_server_method());
191   if(!listener->context) {
192     goto ssl_create_listener_failed;
193   }
194   
195   char *certfile = conf_get_local()->sslcertfile;
196   char *keyfile = conf_get_local()->sslkeyfile;
197   char *cafile = conf_get_local()->sslcafile;
198   
199   if(!certfile) {
200     goto ssl_create_listener_failed;
201   }
202   if(!keyfile) {
203     keyfile = certfile;
204   }
205   
206   /* load certificate */
207   if(SSL_CTX_use_certificate_file(listener->context, certfile, SSL_FILETYPE_PEM) <= 0) {
208     goto ssl_create_listener_failed;
209   }
210   /* load keyfile */
211   if(SSL_CTX_use_PrivateKey_file(listener->context, keyfile, SSL_FILETYPE_PEM) <= 0) {
212     goto ssl_create_listener_failed;
213   }
214   /* check certificate and keyfile */
215   if(!SSL_CTX_check_private_key(listener->context)) {
216     goto ssl_create_listener_failed;
217   }
218   /* load cafile */
219   if(cafile && cafile[0] && SSL_CTX_load_verify_locations(listener->context, cafile, NULL) <= 0) {
220     goto ssl_create_listener_failed;
221   }
222   FlagSet(&listener->flags, SSLFLAG_READY);
223   return listener;
224 ssl_create_listener_failed:
225   free(listener);
226   return NULL;
227 }
228
229 static int ssl_handshake_incoming(struct SSLConnection *connection) {
230   int result = SSL_accept(connection->session);
231         FlagClr(&connection->flags, SSLFLAG_HANDSHAKE_R);
232   FlagClr(&connection->flags, SSLFLAG_HANDSHAKE_W);
233         switch(SSL_get_error(connection->session, result)) {
234                 case SSL_ERROR_NONE:
235                         FlagClr(&connection->flags, SSLFLAG_HANDSHAKE);
236       FlagSet(&connection->flags, SSLFLAG_READY);
237       
238       ssl_handshake_completed(connection, 1);
239                         return 0;
240                 case SSL_ERROR_WANT_READ:
241                         FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_R);
242       return 1;
243                 case SSL_ERROR_WANT_WRITE:
244       FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_W);
245                         return 1;
246                 default:
247       //unset connection! 
248       //Handshake error!
249       ssl_handshake_completed(connection, 0);
250                         return 0;
251         }
252   return 0;
253 }
254
255 struct SSLConnection *ssl_start_handshake_listener(struct SSLListener *listener, int fd, void *data, enum SSLDataType datatype) {
256   if(!listener)
257     return NULL;
258   struct SSLPendingConections *pending = NULL;
259   struct SSLConnection *connection = malloc(sizeof(*connection));
260   connection->session = SSL_new(listener->context);
261   if(!connection->session) {
262     goto ssl_start_handshake_listener_failed;
263   }
264   if(!SSL_set_fd(connection->session, fd)) {
265     goto ssl_start_handshake_listener_failed;
266   }
267   FlagSet(&connection->flags, SSLFLAG_INCOMING);
268   FlagSet(&connection->flags, SSLFLAG_HANDSHAKE);
269   FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_R);
270   
271   pending = malloc(sizeof(*pending));
272   if(!pending) {
273     goto ssl_start_handshake_listener_failed;
274   }
275   pending->connection = connection;
276   pending->next = firstPendingConection;
277   firstPendingConection = pending;
278   
279   pending->data = data;
280   pending->datatype = datatype;
281   
282   ssl_handshake_incoming(connection);
283         return connection;
284 ssl_start_handshake_listener_failed:
285         free(connection);
286         return NULL;
287 }
288
289 IOResult ssl_recv_decrypt(struct SSLConnection *connection, char *buf, unsigned int buflen, unsigned int *len) {
290   if(FlagHas(&connection->flags, SSLFLAG_HANDSHAKE)) {
291     if(FlagHas(&connection->flags, SSLFLAG_INCOMING)) {
292       ssl_handshake_incoming(connection);
293       return IO_BLOCKED;
294     }
295     if(FlagHas(&connection->flags, SSLFLAG_OUTGOING)) {
296       ssl_handshake_outgoing(connection);
297       return IO_BLOCKED;
298     }
299   }
300   
301   *len = SSL_read(connection->session, buf, buflen);
302   FlagClr(&connection->flags, SSLFLAG_HANDSHAKE_R);
303   int err = SSL_get_error(connection->session, *len);
304   switch(err) {
305     case SSL_ERROR_NONE:
306         return IO_SUCCESS;
307                 case SSL_ERROR_ZERO_RETURN:
308                   return IO_FAILURE;
309                 case SSL_ERROR_WANT_READ:
310       FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_R);
311                         return IO_BLOCKED;
312                 case SSL_ERROR_WANT_WRITE:
313                         FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_W);
314                         return IO_BLOCKED;
315                 case SSL_ERROR_SYSCALL:
316                         return IO_FAILURE;
317                 default:
318       return IO_FAILURE;
319   }
320 }
321
322 static ssize_t ssl_writev(SSL *ssl, const struct iovec *vector, int count) {
323   char *buffer;
324   register char *bp;
325   size_t bytes, to_copy;
326   int i;
327
328   /* Find the total number of bytes to be written.  */
329   bytes = 0;
330   for (i = 0; i < count; ++i)
331     bytes += vector[i].iov_len;
332
333   /* Allocate a temporary buffer to hold the data.  */
334   buffer = (char *) alloca (bytes);
335
336   /* Copy the data into BUFFER.  */
337   to_copy = bytes;
338   bp = buffer;
339   for (i = 0; i < count; ++i) {
340     size_t copy = ((vector[i].iov_len) > (to_copy) ? (to_copy) : (vector[i].iov_len));
341     memcpy ((void *) bp, (void *) vector[i].iov_base, copy);
342     bp += copy;
343     to_copy -= copy;
344     if (to_copy == 0)
345       break;
346   }
347   return SSL_write(ssl, buffer, bytes);
348 }
349
350 IOResult ssl_send_encrypt_plain(struct SSLConnection *connection, char* buf, int len) {
351   return SSL_write(connection->session, buf, len);
352 }
353
354 IOResult ssl_send_encrypt(struct SSLConnection *connection, struct MsgQ* buf, unsigned int *count_in, unsigned int *count_out) {
355   int res;
356   int count;
357   struct iovec iov[IOV_MAX];
358
359   assert(0 != buf);
360   assert(0 != count_in);
361   assert(0 != count_out);
362
363   *count_in = 0;
364   count = msgq_mapiov(buf, iov, IOV_MAX, count_in);
365   res = ssl_writev(connection->session, iov, count);
366
367   switch(SSL_get_error(connection->session, res)) {
368     case SSL_ERROR_NONE:
369                 case SSL_ERROR_ZERO_RETURN:
370       *count_out = (unsigned) res;
371       return IO_SUCCESS;
372     case SSL_ERROR_WANT_READ:
373       FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_R);
374                         return IO_BLOCKED;
375     case SSL_ERROR_WANT_WRITE:
376       FlagSet(&connection->flags, SSLFLAG_HANDSHAKE_W);
377                         return IO_BLOCKED;
378     default:
379       *count_out = 0;
380       return IO_FAILURE;
381   }
382 }
383
384 int ssl_connection_flush(struct SSLConnection *connection) {
385   if(connection) {
386     if(ssl_handshake(connection)) {
387       if(FlagHas(&connection->flags, SSLFLAG_INCOMING)) {
388         return ssl_handshake_incoming(connection);
389       }
390       if(FlagHas(&connection->flags, SSLFLAG_OUTGOING)) {
391         return ssl_handshake_outgoing(connection);
392       }
393     }
394   } else {
395     struct SSLPendingConections *curr, *last = NULL, *next;
396     for(curr = firstPendingConection; curr; curr = next) {
397       next = curr->next;
398       if(!ssl_connection_flush(curr->connection)) {
399         // connection is already in auth process here, curr is freed!
400         continue;
401       }
402       last = curr;
403     }
404   }
405   return 0;
406 }
407
408 #else
409 void ssl_free_connection(struct SSLConnection *connection) {}
410 void ssl_free_listener(struct SSLConnection *listener) {}
411 struct SSLListener *ssl_create_listener() { return NULL; }
412 struct SSLConnection *ssl_start_handshake_listener(struct SSLListener *listener, int fd, void *data, enum SSLDataType datatype) { return NULL; }
413 IOResult ssl_recv_decrypt(struct SSLConnection *connection, char *buf, int *len) { return IO_FAILURE; }
414 int ssl_connection_flush(struct SSLConnection *connection) { return 0; };
415 #endif
416