License update
[srvx.git] / src / ioset.c
1 /* ioset.h - srvx event loop
2  * Copyright 2002-2004 srvx Development Team
3  *
4  * This file is part of srvx.
5  *
6  * srvx is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with srvx; if not, write to the Free Software Foundation,
18  * Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA.
19  */
20
21 #include "ioset.h"
22 #include "log.h"
23 #include "timeq.h"
24 #include "saxdb.h"
25 #include "conf.h"
26
27 #ifdef HAVE_FCNTL_H
28 #include <fcntl.h>
29 #endif
30 #ifdef HAVE_SYS_SELECT_H
31 #include <sys/select.h>
32 #endif
33 #ifdef HAVE_SYS_SOCKET_H
34 #include <sys/socket.h>
35 #endif
36
37 #ifndef IOSET_DEBUG
38 #define IOSET_DEBUG 0
39 #endif
40
41 #define IS_EOL(CH) ((CH) == '\n')
42
43 extern int uplink_connect(void);
44 static int clock_skew;
45 int do_write_dbs;
46 int do_reopen;
47
48 static struct io_fd **fds;
49 static unsigned int fds_size;
50 static fd_set read_fds, write_fds;
51
52 static void
53 ioq_init(struct ioq *ioq, int size) {
54     ioq->buf = malloc(size);
55     ioq->get = ioq->put = 0;
56     ioq->size = size;
57 }
58
59 static unsigned int
60 ioq_put_avail(const struct ioq *ioq) {
61     /* Subtract 1 from ioq->get to be sure we don't fill the buffer
62      * and make it look empty even when there's data in it. */
63     if (ioq->put < ioq->get)
64         return ioq->get - ioq->put - 1;
65     else if (ioq->get == 0)
66         return ioq->size - ioq->put - 1;
67     else
68         return ioq->size - ioq->put;
69 }
70
71 static unsigned int
72 ioq_get_avail(const struct ioq *ioq) {
73     return ((ioq->put < ioq->get) ? ioq->size : ioq->put) - ioq->get;
74 }
75
76 static unsigned int
77 ioq_used(const struct ioq *ioq) {
78     return ((ioq->put < ioq->get) ? ioq->size : 0) + ioq->put - ioq->get;
79 }
80
81 static unsigned int
82 ioq_grow(struct ioq *ioq) {
83     int new_size = ioq->size << 1;
84     char *new_buf = malloc(new_size);
85     int get_avail = ioq_get_avail(ioq);
86     memcpy(new_buf, ioq->buf + ioq->get, get_avail);
87     if (ioq->put < ioq->get)
88         memcpy(new_buf + get_avail, ioq->buf, ioq->put);
89     free(ioq->buf);
90     ioq->put = ioq_used(ioq);
91     ioq->get = 0;
92     ioq->buf = new_buf;
93     ioq->size = new_size;
94     return new_size - ioq->put;
95 }
96
97 void
98 ioset_cleanup(void) {
99     free(fds);
100 }
101
102 struct io_fd *
103 ioset_add(int fd) {
104     struct io_fd *res;
105     int flags;
106
107     if (fd < 0) {
108         log_module(MAIN_LOG, LOG_ERROR, "Somebody called ioset_add(%d) on a negative fd!", fd);
109         return 0;
110     }
111     res = calloc(1, sizeof(*res));
112     if (!res)
113         return 0;
114     res->fd = fd;
115     ioq_init(&res->send, 1024);
116     ioq_init(&res->recv, 1024);
117     if ((unsigned)fd >= fds_size) {
118         unsigned int old_size = fds_size;
119         fds_size = fd + 8;
120         fds = realloc(fds, fds_size*sizeof(*fds));
121         memset(fds+old_size, 0, (fds_size-old_size)*sizeof(*fds));
122     }
123     fds[fd] = res;
124     flags = fcntl(fd, F_GETFL);
125     fcntl(fd, F_SETFL, flags|O_NONBLOCK);
126     return res;
127 }
128
129 struct io_fd *
130 ioset_connect(struct sockaddr *local, unsigned int sa_size, const char *peer, unsigned int port, int blocking, void *data, void (*connect_cb)(struct io_fd *fd, int error)) {
131     int fd, res;
132     struct io_fd *io_fd;
133     struct sockaddr_in sin;
134     unsigned long ip;
135
136     if (!getipbyname(peer, &ip)) {
137         log_module(MAIN_LOG, LOG_ERROR, "getipbyname(%s) failed.", peer);
138         return NULL;
139     }
140     sin.sin_addr.s_addr = ip;
141     if (local) {
142         if ((fd = socket(local->sa_family, SOCK_STREAM, 0)) < 0) {
143             log_module(MAIN_LOG, LOG_ERROR, "socket() for %s returned errno %d (%s)", peer, errno, strerror(errno));
144             return NULL;
145         }
146         if (bind(fd, local, sa_size) < 0) {
147             log_module(MAIN_LOG, LOG_ERROR, "bind() of socket for %s (fd %d) returned errno %d (%s).  Will let operating system choose.", peer, fd, errno, strerror(errno));
148         }
149     } else {
150         if ((fd = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
151             log_module(MAIN_LOG, LOG_ERROR, "socket() for %s returned errno %d (%s).", peer, errno, strerror(errno));
152             return NULL;
153         }
154     }
155     sin.sin_family = AF_INET;
156     sin.sin_port = htons(port);
157     if (blocking) {
158         res = connect(fd, (struct sockaddr*)&sin, sizeof(sin));
159         io_fd = ioset_add(fd);
160     } else {
161         io_fd = ioset_add(fd);
162         res = connect(fd, (struct sockaddr*)&sin, sizeof(sin));
163     }
164     if (!io_fd) {
165         close(fd);
166         return NULL;
167     }
168     io_fd->data = data;
169     io_fd->connect_cb = connect_cb;
170     if (res < 0) {
171         switch (errno) {
172         case EINPROGRESS: /* only if !blocking */
173             return io_fd;
174         default:
175             log_module(MAIN_LOG, LOG_ERROR, "connect(%s:%d) (fd %d) returned errno %d (%s).", peer, port, io_fd->fd, errno, strerror(errno));
176             /* then fall through */
177         case EHOSTUNREACH:
178         case ECONNREFUSED:
179             ioset_close(io_fd->fd, 1);
180             return NULL;
181         }
182     }
183     if (connect_cb)
184         connect_cb(io_fd, ((res < 0) ? errno : 0));
185     return io_fd;
186 }
187
188 static void
189 ioset_try_write(struct io_fd *fd) {
190     int res;
191     unsigned int req = ioq_get_avail(&fd->send);
192     res = write(fd->fd, fd->send.buf+fd->send.get, req);
193     if (res < 0) {
194         switch (errno) {
195         case EAGAIN:
196             break;
197         default:
198             log_module(MAIN_LOG, LOG_ERROR, "write() on fd %d error %d: %s", fd->fd, errno, strerror(errno));
199         }
200     } else {
201         fd->send.get += res;
202         if (fd->send.get == fd->send.size)
203             fd->send.get = 0;
204     }
205 }
206
207 void
208 ioset_close(int fd, int os_close) {
209     struct io_fd *fdp;
210     if (!(fdp = fds[fd]))
211         return;
212     fds[fd] = NULL;
213     if (fdp->destroy_cb)
214         fdp->destroy_cb(fdp);
215     if (fdp->send.get != fdp->send.put) {
216         int flags = fcntl(fd, F_GETFL);
217         fcntl(fd, F_SETFL, flags&~O_NONBLOCK);
218         ioset_try_write(fdp);
219         /* it may need to send the beginning of the buffer now.. */
220         if (fdp->send.get != fdp->send.put)
221             ioset_try_write(fdp);
222     }
223     free(fdp->send.buf);
224     free(fdp->recv.buf);
225     if (os_close)
226         close(fd);
227     free(fdp);
228     FD_CLR(fd, &read_fds);
229     FD_CLR(fd, &write_fds);
230 }
231
232 static int
233 ioset_find_line_length(struct io_fd *fd) {
234     unsigned int pos, max, len;
235     len = 0;
236     max = (fd->recv.put < fd->recv.get) ? fd->recv.size : fd->recv.put;
237     for (pos = fd->recv.get; pos < max; ++pos, ++len)
238         if (IS_EOL(fd->recv.buf[pos]))
239             return fd->line_len = len + 1;
240     if (fd->recv.put < fd->recv.get)
241         for (pos = 0; pos < fd->recv.put; ++pos, ++len)
242             if (IS_EOL(fd->recv.buf[pos]))
243                 return fd->line_len = len + 1;
244     return fd->line_len = 0;
245 }
246
247 static void
248 ioset_buffered_read(struct io_fd *fd) {
249     int put_avail, nbr, fdnum;
250     
251     if (!(put_avail = ioq_put_avail(&fd->recv)))
252         put_avail = ioq_grow(&fd->recv);
253     nbr = read(fd->fd, fd->recv.buf + fd->recv.put, put_avail);
254     if (nbr < 0) {
255         switch (errno) {
256         case EAGAIN:
257             break;
258         default:
259             log_module(MAIN_LOG, LOG_ERROR, "Unexpected read() error %d on fd %d: %s", errno, fd->fd, strerror(errno));
260             /* Just flag it as EOF and call readable_cb() to notify the fd's owner. */
261             fd->eof = 1;
262             fd->wants_reads = 0;
263             fd->readable_cb(fd);
264         }
265     } else if (nbr == 0) {
266         fd->eof = 1;
267         fd->wants_reads = 0;
268         fd->readable_cb(fd);
269     } else {
270         if (fd->line_len == 0) {
271             unsigned int pos;
272             for (pos = fd->recv.put; pos < fd->recv.put + nbr; ++pos) {
273                 if (IS_EOL(fd->recv.buf[pos])) {
274                     if (fd->recv.put < fd->recv.get)
275                         fd->line_len = fd->recv.size + pos + 1 - fd->recv.get;
276                     else
277                         fd->line_len = pos + 1 - fd->recv.get;
278                     break;
279                 }
280             }
281         }
282         fd->recv.put += nbr;
283         if (fd->recv.put == fd->recv.size)
284             fd->recv.put = 0;
285         fdnum = fd->fd;
286         while (fd->wants_reads && (fd->line_len > 0)) {
287             fd->readable_cb(fd);
288             if (!fds[fdnum])
289                 break; /* make sure they didn't close on us */
290             ioset_find_line_length(fd);
291         }
292     }
293 }
294
295 int
296 ioset_line_read(struct io_fd *fd, char *dest, int max) {
297     int avail, done;
298     if (fd->eof && (!ioq_get_avail(&fd->recv) ||  (fd->line_len < 0)))
299         return 0;
300     if (fd->line_len < 0)
301         return -1;
302     if (fd->line_len < max)
303         max = fd->line_len;
304     avail = ioq_get_avail(&fd->recv);
305     if (max > avail) {
306         memcpy(dest, fd->recv.buf + fd->recv.get, avail);
307         fd->recv.get += avail;
308         assert(fd->recv.get == fd->recv.size);
309         fd->recv.get = 0;
310         done = avail;
311     } else {
312         done = 0;
313     }
314     memcpy(dest + done, fd->recv.buf + fd->recv.get, max - done);
315     fd->recv.get += max - done;
316     if (fd->recv.get == fd->recv.size)
317         fd->recv.get = 0;
318     dest[max] = 0;
319     ioset_find_line_length(fd);
320     return max;
321 }
322
323 #if 1
324 #define debug_fdsets(MSG, NFDS, READ_FDS, WRITE_FDS, EXCEPT_FDS, SELECT_TIMEOUT) (void)0
325 #else
326 static void
327 debug_fdsets(const char *msg, int nfds, fd_set *read_fds, fd_set *write_fds, fd_set *except_fds, struct timeval *select_timeout) {
328     static const char *flag_text[8] = { "---", "r", "w", "rw", "e", "er", "ew", "erw" };
329     char buf[MAXLEN];
330     int pos, ii, flags;
331     struct timeval now;
332
333     for (pos=ii=0; ii<nfds; ++ii) {
334         flags  = (read_fds && FD_ISSET(ii, read_fds)) ? 1 : 0;
335         flags |= (write_fds && FD_ISSET(ii, write_fds)) ? 2 : 0;
336         flags |= (except_fds && FD_ISSET(ii, except_fds)) ? 4 : 0;
337         if (!flags)
338             continue;
339         pos += sprintf(buf+pos, " %d%s", ii, flag_text[flags]);
340     }
341     gettimeofday(&now, NULL);
342     if (select_timeout) {
343         log_module(MAIN_LOG, LOG_DEBUG, "%s, at "FMT_TIME_T".%06ld:%s (timeout "FMT_TIME_T".%06ld)", msg, now.tv_sec, now.tv_usec, buf, select_timeout->tv_sec, select_timeout->tv_usec);
344     } else {
345         log_module(MAIN_LOG, LOG_DEBUG, "%s, at "FMT_TIME_T".%06ld:%s (no timeout)", msg, now.tv_sec, now.tv_usec, buf);
346     }
347 }
348 #endif
349
350 void
351 ioset_run(void) {
352     extern struct io_fd *socket_io_fd;
353     struct timeval select_timeout;
354     unsigned int nn;
355     int select_result, max_fd;
356     time_t wakey;
357     struct io_fd *fd;
358
359     while (!quit_services) {
360         while (!socket_io_fd)
361             uplink_connect();
362
363         /* How long to sleep? (fill in select_timeout) */
364         wakey = timeq_next();
365         if ((wakey - now) < 0)
366             select_timeout.tv_sec = 0;
367         else
368             select_timeout.tv_sec = wakey - now;
369         select_timeout.tv_usec = 0;
370
371         /* Set up read_fds and write_fds fdsets. */
372         FD_ZERO(&read_fds);
373         FD_ZERO(&write_fds);
374         max_fd = 0;
375         for (nn=0; nn<fds_size; nn++) {
376             if (!(fd = fds[nn]))
377                 continue;
378             max_fd = nn;
379             if (fd->wants_reads)
380                 FD_SET(nn, &read_fds);
381             if ((fd->send.get != fd->send.put) || !fd->connected)
382                 FD_SET(nn, &write_fds);
383         }
384
385         /* Check for activity, update time. */
386         debug_fdsets("Entering select", max_fd+1, &read_fds, &write_fds, NULL, &select_timeout);
387         select_result = select(max_fd + 1, &read_fds, &write_fds, NULL, &select_timeout);
388         debug_fdsets("After select", max_fd+1, &read_fds, &write_fds, NULL, &select_timeout);
389         now = time(NULL) + clock_skew;
390         if (select_result < 0) {
391             if (errno != EINTR) {
392                 log_module(MAIN_LOG, LOG_ERROR, "select() error %d: %s", errno, strerror(errno));
393                 close_socket();
394             }
395             continue;
396         }
397
398         /* Call back anybody that has connect or read activity and wants to know. */
399         for (nn=0; nn<fds_size; nn++) {
400             if (!(fd = fds[nn]))
401                 continue;
402             if (FD_ISSET(nn, &read_fds)) {
403                 if (fd->line_reads)
404                     ioset_buffered_read(fd);
405                 else
406                     fd->readable_cb(fd);
407             }
408             if (FD_ISSET(nn, &write_fds) && !fd->connected) {
409                 int rc, arglen = sizeof(rc);
410                 if (getsockopt(fd->fd, SOL_SOCKET, SO_ERROR, &rc, &arglen) < 0)
411                     rc = errno;
412                 fd->connected = 1;
413                 if (fd->connect_cb)
414                     fd->connect_cb(fd, rc);
415             }
416             /* Note: check whether write FD is still set, since the
417              * connect_cb() might close the FD, making us dereference
418              * a free()'d pointer for the fd.
419              */
420             if (FD_ISSET(nn, &write_fds) && (fd->send.get != fd->send.put))
421                 ioset_try_write(fd);
422         }
423
424         /* Call any timeq events we need to call. */
425         timeq_run();
426         if (do_write_dbs) {
427             saxdb_write_all();
428             do_write_dbs = 0;
429         }
430         if (do_reopen) {
431             extern char *services_config;
432             conf_read(services_config);
433             do_reopen = 0;
434         }
435     }
436 }
437
438 void
439 ioset_write(struct io_fd *fd, const char *buf, unsigned int nbw) {
440     unsigned int avail;
441     while (ioq_used(&fd->send) + nbw >= fd->send.size)
442         ioq_grow(&fd->send);
443     avail = ioq_put_avail(&fd->send);
444     if (nbw > avail) {
445         memcpy(fd->send.buf + fd->send.put, buf, avail);
446         buf += avail;
447         nbw -= avail;
448         fd->send.put = 0;
449     }
450     memcpy(fd->send.buf + fd->send.put, buf, nbw);
451     fd->send.put += nbw;
452     if (fd->send.put == fd->send.size)
453         fd->send.put = 0;
454 }
455
456 void
457 ioset_set_time(unsigned long new_now) {
458     clock_skew = new_now - time(NULL);
459     now = new_now;
460 }