fix possible crash on user deletion
[srvx.git] / src / dict-splay.c
1 /* dict-splay.c - Abstract dictionary type
2  * Copyright 2000-2004 srvx Development Team
3  *
4  * This file is part of srvx.
5  *
6  * srvx is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  */
16
17 #include "common.h"
18 #include "dict.h"
19
20 /*
21  *    Create new dictionary.
22  */
23 dict_t
24 dict_new(void)
25 {
26     dict_t dict = calloc(1, sizeof(*dict));
27     return dict;
28 }
29
30 /*
31  *    Return number of entries in the dictionary.
32  */
33 unsigned int
34 dict_size(dict_t dict)
35 {
36     return dict->count;
37 }
38
39 /*
40  *    Set the function to be called when freeing a key structure.
41  *    If the function is NULL, just forget about the pointer.
42  */
43 void
44 dict_set_free_keys(dict_t dict, free_f free_keys)
45 {
46     dict->free_keys = free_keys;
47 }
48
49 /*
50  *    Set the function to free data.
51  * If the function is NULL, just forget about the pointer.
52  */
53 void
54 dict_set_free_data(dict_t dict, free_f free_data)
55 {
56     dict->free_data = free_data;
57 }
58
59 const char *
60 dict_foreach(dict_t dict, dict_iterator_f it_f, void *extra)
61 {
62     dict_iterator_t it;
63
64     for (it=dict_first(dict); it; it=iter_next(it)) {
65         if (it_f(iter_key(it), iter_data(it), extra)) return iter_key(it);
66     }
67     return NULL;
68 }
69
70 /*
71  *   This function finds a node and pulls it to the top of the tree.
72  *   This helps balance the tree and auto-cache things you search for.
73  */
74 static struct dict_node*
75 dict_splay(struct dict_node *node, const char *key)
76 {
77     struct dict_node N, *l, *r, *y;
78     int res;
79
80     if (!node) return NULL;
81     N.l = N.r = NULL;
82     l = r = &N;
83
84     while (1) {
85         verify(node);
86         res = irccasecmp(key, node->key);
87         if (!res) break;
88         if (res < 0) {
89             if (!node->l) break;
90             res = irccasecmp(key, node->l->key);
91             if (res < 0) {
92                 y = node->l;
93                 node->l = y->r;
94                 y->r = node;
95                 node = y;
96                 if (!node->l) break;
97             }
98             r->l = node;
99             r = node;
100             node = node->l;
101         } else { /* res > 0 */
102             if (!node->r) break;
103             res = irccasecmp(key, node->r->key);
104             if (res > 0) {
105                 y = node->r;
106                 node->r = y->l;
107                 y->l = node;
108                 node = y;
109                 if (!node->r) break;
110             }
111             l->r = node;
112             l = node;
113             node = node->r;
114         }
115     }
116     l->r = node->l;
117     r->l = node->r;
118     node->l = N.r;
119     node->r = N.l;
120     return node;
121 }
122
123 /*
124  *    Free node.  Free data/key using free_f functions.
125  */
126 static void
127 dict_dispose_node(struct dict_node *node, free_f free_keys, free_f free_data)
128 {
129     if (free_keys && node->key) {
130         if (free_keys == free)
131             free((void*)node->key);
132         else
133             free_keys((void*)node->key);
134     }
135     if (free_data && node->data) {
136         if (free_data == free)
137             free(node->data);
138         else
139             free_data(node->data);
140     }
141     free(node);
142 }
143
144 /*
145  *    Insert an entry into the dictionary.
146  *    Key ordering (and uniqueness) is determined by case-insensitive
147  *    string comparison.
148  */
149 void
150 dict_insert(dict_t dict, const char *key, void *data)
151 {
152     struct dict_node *new_node;
153     if (!key)
154         return;
155     verify(dict);
156     new_node = malloc(sizeof(struct dict_node));
157     new_node->key = key;
158     new_node->data = data;
159     if (dict->root) {
160         int res;
161         dict->root = dict_splay(dict->root, key);
162         res = irccasecmp(key, dict->root->key);
163         if (res < 0) {
164             /* insert just "before" current root */
165             new_node->l = dict->root->l;
166             new_node->r = dict->root;
167             dict->root->l = NULL;
168             if (dict->root->prev) {
169                 dict->root->prev->next = new_node;
170             } else {
171                 dict->first = new_node;
172             }
173             new_node->prev = dict->root->prev;
174             new_node->next = dict->root;
175             dict->root->prev = new_node;
176             dict->root = new_node;
177         } else if (res > 0) {
178             /* insert just "after" current root */
179             new_node->r = dict->root->r;
180             new_node->l = dict->root;
181             dict->root->r = NULL;
182             if (dict->root->next) {
183                 dict->root->next->prev = new_node;
184             } else {
185                 dict->last = new_node;
186             }
187             new_node->next = dict->root->next;
188             new_node->prev = dict->root;
189             dict->root->next = new_node;
190             dict->root = new_node;
191         } else {
192             /* maybe we don't want to overwrite it .. oh well */
193             if (dict->free_data) {
194                 if (dict->free_data == free)
195                     free(dict->root->data);
196                 else
197                     dict->free_data(dict->root->data);
198             }
199             if (dict->free_keys) {
200                 if (dict->free_keys == free)
201                     free((void*)dict->root->key);
202                 else
203                     dict->free_keys((void*)dict->root->key);
204             }
205             free(new_node);
206             dict->root->key = key;
207             dict->root->data = data;
208             /* decrement the count since we dropped the node */
209             dict->count--;
210         }
211     } else {
212         new_node->l = new_node->r = NULL;
213         new_node->next = new_node->prev = NULL;
214         dict->root = dict->first = dict->last = new_node;
215     }
216     dict->count++;
217 }
218
219 /*
220  *    Remove an entry from the dictionary.
221  *    Return non-zero if it was found, or zero if the key was not in the
222  *    dictionary.
223  */
224 int
225 dict_remove2(dict_t dict, const char *key, int no_dispose)
226 {
227     struct dict_node *new_root, *old_root;
228
229     if (!dict->root)
230         return 0;
231     verify(dict);
232     dict->root = dict_splay(dict->root, key);
233     if (irccasecmp(key, dict->root->key))
234         return 0;
235
236     if (!dict->root->l) {
237         new_root = dict->root->r;
238     } else {
239         new_root = dict_splay(dict->root->l, key);
240         new_root->r = dict->root->r;
241     }
242     if (dict->root->prev) dict->root->prev->next = dict->root->next;
243     if (dict->first == dict->root) dict->first = dict->first->next;
244     if (dict->root->next) dict->root->next->prev = dict->root->prev;
245     if (dict->last == dict->root) dict->last = dict->last->prev;
246     old_root = dict->root;
247     dict->root = new_root;
248     dict->count--;
249     if (no_dispose) {
250         free(old_root);
251     } else {
252         dict_dispose_node(old_root, dict->free_keys, dict->free_data);
253     }
254     return 1;
255 }
256
257 /*
258  *    Find an entry in the dictionary.
259  *    If "found" is non-NULL, set it to non-zero if the key was found.
260  *    Return the data associated with the key (or NULL if the key was
261  *    not found).
262  */
263 void*
264 dict_find(dict_t dict, const char *key, int *found)
265 {
266     int was_found;
267     if (!dict || !dict->root || !key) {
268         if (found)
269             *found = 0;
270         return NULL;
271     }
272     verify(dict);
273     dict->root = dict_splay(dict->root, key);
274     was_found = !irccasecmp(key, dict->root->key);
275     if (found)
276         *found = was_found;
277     return was_found ? dict->root->data : NULL;
278 }
279
280 /*
281  *    Delete an entire dictionary.
282  */
283 void
284 dict_delete(dict_t dict)
285 {
286     dict_iterator_t it, next;
287     if (!dict)
288         return;
289     verify(dict);
290     for (it=dict_first(dict); it; it=next) {
291         next = iter_next(it);
292         dict_dispose_node(it, dict->free_keys, dict->free_data);
293     }
294     free(dict);
295 }
296
297 struct dict_sanity_struct {
298     unsigned int node_count;
299     struct dict_node *bad_node;
300     char error[128];
301 };
302
303 static int
304 dict_sanity_check_node(struct dict_node *node, struct dict_sanity_struct *dss)
305 {
306     verify(node);
307     if (!node->key) {
308         snprintf(dss->error, sizeof(dss->error), "Node %p had null key", (void*)node);
309         return 1;
310     }
311     if (node->l) {
312         if (dict_sanity_check_node(node->l, dss)) return 1;
313         if (irccasecmp(node->l->key, node->key) >= 0) {
314             snprintf(dss->error, sizeof(dss->error), "Node %p's left child's key '%s' >= its key '%s'", (void*)node, node->l->key, node->key);
315             return 1;
316         }
317     }
318     if (node->r) {
319         if (dict_sanity_check_node(node->r, dss)) return 1;
320         if (irccasecmp(node->key, node->r->key) >= 0) {
321             snprintf(dss->error, sizeof(dss->error), "Node %p's right child's key '%s' <= its key '%s'", (void*)node, node->r->key, node->key);
322             return 1;
323         }
324     }
325     dss->node_count++;
326     return 0;
327 }
328
329 /*
330  *    Perform sanity checks on the dict's internal structure.
331  */
332 char *
333 dict_sanity_check(dict_t dict)
334 {
335     struct dict_sanity_struct dss;
336     dss.node_count = 0;
337     dss.bad_node = 0;
338     dss.error[0] = 0;
339     verify(dict);
340     if (dict->root && dict_sanity_check_node(dict->root, &dss)) {
341         return strdup(dss.error);
342     } else if (dss.node_count != dict->count) {
343         snprintf(dss.error, sizeof(dss.error), "Counted %d nodes but expected %d.", dss.node_count, dict->count);
344         return strdup(dss.error);
345     } else {
346         return 0;
347     }
348 }