#include <stdlib.h>
#include <search.h>

/*
 * AVL-tree implementation of the POSIX tsearch family.
 *
 * Every node stores the caller's key pointer as its FIRST member, so the
 * node pointers handed back by tsearch/tfind/tdelete/twalk can be
 * dereferenced as (void **) to recover the key.  a[0] is the left
 * (compares smaller) child, a[1] the right (compares larger) child, and
 * h is the height of the subtree rooted at this node (1 for a leaf; an
 * empty subtree has height 0, see height()).
 */
struct node {
	const void *key;
	void *a[2];
	int h;
};

/* Height of a possibly-empty subtree. */
static int height(struct node *n)
{
	return n ? n->h : 0;
}

/*
 * Rebalance the subtree rooted at x whose !dir side is two levels taller
 * than its dir side, moving x down toward dir.  Returns the new subtree
 * root with every affected height updated.
 *
 * The height arithmetic below assumes the imbalance is exactly 2, which
 * is what a single insertion or deletion can produce.  A double rotation
 * is used when the inner grandchild z is taller than the outer one,
 * otherwise a single rotation.
 */
static struct node *rot(struct node *x, int dir)
{
	struct node *y = x->a[!dir];
	struct node *z = y->a[dir];
	int hz = height(z);

	if (hz > height(y->a[!dir])) {
		/*
		 *      x
		 * dir / \ !dir         z
		 *    A   y           /   \
		 *       / \   -->   x     y
		 *      z   D       /|     |\
		 *     / \         A B     C D
		 *    B   C
		 */
		x->a[!dir] = z->a[dir];
		y->a[dir] = z->a[!dir];
		z->a[dir] = x;
		z->a[!dir] = y;
		x->h = hz;
		y->h = hz;
		z->h = hz + 1;
	} else {
		/*
		 *    x               y
		 *   / \             / \
		 *  A   y    -->    x   D
		 *     / \         / \
		 *    z   D       A   z
		 */
		x->a[!dir] = z;
		y->a[dir] = x;
		x->h = hz + 1;
		y->h = hz + 2;
		z = y;
	}
	return z;
}

/*
 * Recompute n's height from its children and, if their heights differ by
 * more than one, rotate toward the shorter side.  Returns the (possibly
 * new) root of the subtree.
 */
static struct node *balance(struct node *n)
{
	int h0 = height(n->a[0]);
	int h1 = height(n->a[1]);

	/* h0-h1 in {-1,0,1} maps to {0,1,2} in unsigned arithmetic; any
	   larger difference in either direction lands outside [0,3). */
	if (h0 - h1 + 1u < 3) {
		n->h = h0 > h1 ? h0 + 1 : h1 + 1;
		return n;
	}
	return rot(n, h0 > h1);
}

/*
 * Unlink and free the rightmost (largest) node of the non-empty subtree n,
 * storing its key in *pkey.  Returns the rebalanced remainder of the
 * subtree (possibly empty).
 */
static struct node *remove_rightmost(struct node *n, const void **pkey)
{
	if (!n->a[1]) {
		struct node *left = n->a[0];
		*pkey = n->key;
		free(n);
		return left;
	}
	n->a[1] = remove_rightmost(n->a[1], pkey);
	return balance(n);
}

/*
 * Remove the node matching key from the subtree at *p, rebalancing every
 * node on the path back up.
 *
 * parent is the value to report if the match is the node at *p: callers
 * pass the node that owns the link p, or tdelete's sentinel for the root.
 * Returns that value when a node was removed, or 0 when key is absent
 * (the tree is then left untouched).
 *
 * (Named delete_node rather than remove: "remove" is reserved by
 * <stdio.h> and clashes with its declaration there.)
 */
static void *delete_node(const void *key, void **p,
	int (*cmp)(const void *, const void *), void *parent)
{
	struct node *n = *p;

	if (!n)
		return 0;
	int c = cmp(key, n->key);
	if (!c) {
		if (n->a[0]) {
			/* Keep n in place: move its in-order predecessor's key
			   into it and free the predecessor node instead.  Any
			   node pointer previously returned for that predecessor
			   key is invalidated by this. */
			n->a[0] = remove_rightmost(n->a[0], &n->key);
			*p = balance(n);
		} else {
			*p = n->a[1];
			free(n);
		}
		return parent;
	}
	void *ret = delete_node(key, &n->a[c > 0], cmp, n);
	if (ret)
		*p = balance(n);
	return ret;
}

/*
 * POSIX tdelete: remove the node matching key from the tree *rootp.
 * Returns the parent of the deleted node, a non-null value when the
 * deleted node was the root, or 0 if key was not found or rootp is 0.
 */
void *tdelete(const void *restrict key, void **restrict rootp,
	int (*cmp)(const void *, const void *))
{
	if (!rootp)
		return 0;
	/* POSIX leaves the value unspecified (but non-null) when the root is
	   deleted.  rootp is used because it stays valid; the old root node
	   itself may have been freed by the deletion. */
	return delete_node(key, rootp, cmp, rootp);
}

/*
 * Insert key into the subtree at *p unless an equal key is already there.
 * *found receives the matching or newly created node (0 if malloc fails).
 * Returns the new node when one was created, which tells the ancestors to
 * rebalance, and 0 otherwise.
 */
static struct node *insert(const void *key, void **p,
	int (*cmp)(const void *, const void *), struct node **found)
{
	struct node *n = *p;
	struct node *r;

	if (!n) {
		n = malloc(sizeof *n);
		if (n) {
			n->key = key;
			n->a[0] = n->a[1] = 0;
			n->h = 1;
		}
		*p = n;
		*found = n;
		return n;
	}
	int c = cmp(key, n->key);
	if (!c) {
		*found = n;
		return 0;
	}
	r = insert(key, &n->a[c > 0], cmp, found);
	if (r)
		*p = balance(n);
	return r;
}

/*
 * POSIX tsearch: find key in the tree *rootp, adding it if absent.
 * Returns the node holding the (existing or new) key, or 0 if rootp is 0
 * or allocation fails.
 */
void *tsearch(const void *key, void **rootp,
	int (*cmp)(const void *, const void *))
{
	if (!rootp)
		return 0;
	struct node *found;
	insert(key, rootp, cmp, &found);
	return found;
}

/* Iterative lookup of key in the subtree n; returns the node or 0. */
static struct node *find(const void *key, struct node *n,
	int (*cmp)(const void *, const void *))
{
	while (n) {
		int c = cmp(key, n->key);
		if (!c)
			break;
		n = n->a[c > 0];
	}
	return n;
}

/* POSIX tfind: like tsearch but never inserts; returns 0 if absent. */
void *tfind(const void *key, void *const *rootp,
	int (*cmp)(const void *, const void *))
{
	if (!rootp)
		return 0;
	return find(key, *rootp, cmp);
}

/*
 * Depth-first traversal.  A leaf is reported once as `leaf`; an internal
 * node three times: `preorder` before its left subtree, `postorder`
 * between the subtrees and `endorder` after its right subtree.  d is the
 * depth (0 at the root).
 */
static void walk(const struct node *r,
	void (*action)(const void *, VISIT, int), int d)
{
	if (!r)
		return;
	if (r->h == 1) {
		action(r, leaf, d);
	} else {
		action(r, preorder, d);
		walk(r->a[0], action, d + 1);
		action(r, postorder, d);
		walk(r->a[1], action, d + 1);
		action(r, endorder, d);
	}
}

/* POSIX twalk: traverse the tree rooted at root, calling action. */
void twalk(const void *root, void (*action)(const void *, VISIT, int))
{
	walk(root, action, 0);
}