fix possible crash on user deletion
[srvx.git] / src / dict-splay.c
index 42bb9025778342b5107b190ee4f70cc0ac04b55c..2c8a282a77413c38fd9940a3cdaa487cef90f9a2 100644 (file)
@@ -75,40 +75,43 @@ static struct dict_node*
 dict_splay(struct dict_node *node, const char *key)
 {
     struct dict_node N, *l, *r, *y;
+    int res;
+
     if (!node) return NULL;
     N.l = N.r = NULL;
     l = r = &N;
 
     while (1) {
-       int res = irccasecmp(key, node->key);
-       if (!res) break;
-       if (res < 0) {
-           if (!node->l) break;
-           res = irccasecmp(key, node->l->key);
-           if (res < 0) {
-               y = node->l;
-               node->l = y->r;
-               y->r = node;
-               node = y;
-               if (!node->l) break;
-           }
-           r->l = node;
-           r = node;
-           node = node->l;
-       } else { /* res > 0 */
-           if (!node->r) break;
-           res = irccasecmp(key, node->r->key);
-           if (res > 0) {
-               y = node->r;
-               node->r = y->l;
-               y->l = node;
-               node = y;
-               if (!node->r) break;
-           }
-           l->r = node;
-           l = node;
-           node = node->r;
-       }
+        verify(node);
+        res = irccasecmp(key, node->key);
+        if (!res) break;
+        if (res < 0) {
+            if (!node->l) break;
+            res = irccasecmp(key, node->l->key);
+            if (res < 0) {
+                y = node->l;
+                node->l = y->r;
+                y->r = node;
+                node = y;
+                if (!node->l) break;
+            }
+            r->l = node;
+            r = node;
+            node = node->l;
+        } else { /* res > 0 */
+            if (!node->r) break;
+            res = irccasecmp(key, node->r->key);
+            if (res > 0) {
+                y = node->r;
+                node->r = y->l;
+                y->l = node;
+                node = y;
+                if (!node->r) break;
+            }
+            l->r = node;
+            l = node;
+            node = node->r;
+        }
     }
     l->r = node->l;
     r->l = node->r;
@@ -123,10 +126,18 @@ dict_splay(struct dict_node *node, const char *key)
 static void
 dict_dispose_node(struct dict_node *node, free_f free_keys, free_f free_data)
 {
-    if (free_keys && node->key)
-        free_keys((void*)node->key);
-    if (free_data && node->data)
-        free_data(node->data);
+    if (free_keys && node->key) {
+        if (free_keys == free)
+            free((void*)node->key);
+        else
+            free_keys((void*)node->key);
+    }
+    if (free_data && node->data) {
+        if (free_data == free)
+            free(node->data);
+        else
+            free_data(node->data);
+    }
     free(node);
 }
 
@@ -141,18 +152,19 @@ dict_insert(dict_t dict, const char *key, void *data)
     struct dict_node *new_node;
     if (!key)
         return;
+    verify(dict);
     new_node = malloc(sizeof(struct dict_node));
     new_node->key = key;
     new_node->data = data;
     if (dict->root) {
-       int res;
-       dict->root = dict_splay(dict->root, key);
-       res = irccasecmp(key, dict->root->key);
-       if (res < 0) {
+        int res;
+        dict->root = dict_splay(dict->root, key);
+        res = irccasecmp(key, dict->root->key);
+        if (res < 0) {
             /* insert just "before" current root */
-           new_node->l = dict->root->l;
-           new_node->r = dict->root;
-           dict->root->l = NULL;
+            new_node->l = dict->root->l;
+            new_node->r = dict->root;
+            dict->root->l = NULL;
             if (dict->root->prev) {
                 dict->root->prev->next = new_node;
             } else {
@@ -161,12 +173,12 @@ dict_insert(dict_t dict, const char *key, void *data)
             new_node->prev = dict->root->prev;
             new_node->next = dict->root;
             dict->root->prev = new_node;
-           dict->root = new_node;
-       } else if (res > 0) {
+            dict->root = new_node;
+        } else if (res > 0) {
             /* insert just "after" current root */
-           new_node->r = dict->root->r;
-           new_node->l = dict->root;
-           dict->root->r = NULL;
+            new_node->r = dict->root->r;
+            new_node->l = dict->root;
+            dict->root->r = NULL;
             if (dict->root->next) {
                 dict->root->next->prev = new_node;
             } else {
@@ -175,21 +187,31 @@ dict_insert(dict_t dict, const char *key, void *data)
             new_node->next = dict->root->next;
             new_node->prev = dict->root;
             dict->root->next = new_node;
-           dict->root = new_node;
-       } else {
-           /* maybe we don't want to overwrite it .. oh well */
-           if (dict->free_data) dict->free_data(dict->root->data);
-            if (dict->free_keys) dict->free_keys((void*)dict->root->key);
+            dict->root = new_node;
+        } else {
+            /* maybe we don't want to overwrite it .. oh well */
+            if (dict->free_data) {
+                if (dict->free_data == free)
+                    free(dict->root->data);
+                else
+                    dict->free_data(dict->root->data);
+            }
+            if (dict->free_keys) {
+                if (dict->free_keys == free)
+                    free((void*)dict->root->key);
+                else
+                    dict->free_keys((void*)dict->root->key);
+            }
             free(new_node);
             dict->root->key = key;
-           dict->root->data = data;
-           /* decrement the count since we dropped the node */
-           dict->count--;
-       }
+            dict->root->data = data;
+            /* decrement the count since we dropped the node */
+            dict->count--;
+        }
     } else {
-       new_node->l = new_node->r = NULL;
+        new_node->l = new_node->r = NULL;
         new_node->next = new_node->prev = NULL;
-       dict->root = dict->first = dict->last = new_node;
+        dict->root = dict->first = dict->last = new_node;
     }
     dict->count++;
 }
@@ -206,6 +228,7 @@ dict_remove2(dict_t dict, const char *key, int no_dispose)
 
     if (!dict->root)
         return 0;
+    verify(dict);
     dict->root = dict_splay(dict->root, key);
     if (irccasecmp(key, dict->root->key))
         return 0;
@@ -242,10 +265,11 @@ dict_find(dict_t dict, const char *key, int *found)
 {
     int was_found;
     if (!dict || !dict->root || !key) {
-       if (found)
+        if (found)
             *found = 0;
-       return NULL;
+        return NULL;
     }
+    verify(dict);
     dict->root = dict_splay(dict->root, key);
     was_found = !irccasecmp(key, dict->root->key);
     if (found)
@@ -262,6 +286,7 @@ dict_delete(dict_t dict)
     dict_iterator_t it, next;
     if (!dict)
         return;
+    verify(dict);
     for (it=dict_first(dict); it; it=next) {
         next = iter_next(it);
         dict_dispose_node(it, dict->free_keys, dict->free_data);
@@ -278,21 +303,22 @@ struct dict_sanity_struct {
 static int
 dict_sanity_check_node(struct dict_node *node, struct dict_sanity_struct *dss)
 {
+    verify(node);
     if (!node->key) {
-        snprintf(dss->error, sizeof(dss->error), "Node %p had null key", node);
+        snprintf(dss->error, sizeof(dss->error), "Node %p had null key", (void*)node);
         return 1;
     }
     if (node->l) {
         if (dict_sanity_check_node(node->l, dss)) return 1;
         if (irccasecmp(node->l->key, node->key) >= 0) {
-            snprintf(dss->error, sizeof(dss->error), "Node %p's left child's key '%s' >= its key '%s'", node, node->l->key, node->key);
+            snprintf(dss->error, sizeof(dss->error), "Node %p's left child's key '%s' >= its key '%s'", (void*)node, node->l->key, node->key);
             return 1;
         }
     }
     if (node->r) {
         if (dict_sanity_check_node(node->r, dss)) return 1;
         if (irccasecmp(node->key, node->r->key) >= 0) {
-            snprintf(dss->error, sizeof(dss->error), "Node %p's right child's key '%s' <= its key '%s'", node, node->r->key, node->key);
+            snprintf(dss->error, sizeof(dss->error), "Node %p's right child's key '%s' <= its key '%s'", (void*)node, node->r->key, node->key);
             return 1;
         }
     }
@@ -310,6 +336,7 @@ dict_sanity_check(dict_t dict)
     dss.node_count = 0;
     dss.bad_node = 0;
     dss.error[0] = 0;
+    verify(dict);
     if (dict->root && dict_sanity_check_node(dict->root, &dss)) {
         return strdup(dss.error);
     } else if (dss.node_count != dict->count) {