1#include "../include/trie.h"
13typedef struct trie_node {
15 struct trie_node** children;
32static trie_node* node_create(
void) {
33 trie_node* n = (trie_node*)calloc(1,
sizeof(trie_node));
38static void node_destroy(trie_node* n) {
40 for (uint8_t i = 0; i < n->nchildren; i++) node_destroy(n->children[i]);
54static inline int child_index(
const trie_node* n, uint8_t c) {
55 int lo = 0, hi = (int)n->nchildren - 1;
57 int mid = (lo + hi) >> 1;
58 if (n->chars[mid] == c)
60 else if (n->chars[mid] < c)
73static trie_node* child_find_or_create(trie_node* n, uint8_t c) {
74 int idx = child_index(n, c);
75 if (idx >= 0)
return n->children[idx];
78 if (n->nchildren == n->nalloc) {
79 uint8_t new_alloc = n->nalloc == 0 ? 2 : (uint8_t)(n->nalloc * 2);
80 if (new_alloc < n->nalloc) new_alloc = 255;
82 uint8_t* nc = (uint8_t*)realloc(n->chars, new_alloc *
sizeof(uint8_t));
83 trie_node** np = (trie_node**)realloc(n->children, new_alloc *
sizeof(trie_node*));
93 n->nalloc = new_alloc;
97 int8_t pos = (int8_t)n->nchildren;
98 while (pos > 0 && n->chars[pos - 1] > c) {
99 n->chars[pos] = n->chars[pos - 1];
100 n->children[pos] = n->children[pos - 1];
104 trie_node* child = node_create();
105 if (!child)
return NULL;
108 n->children[pos] = child;
120 t->root = node_create();
131 node_destroy(t->root);
136 if (!t || !word || !*word)
return false;
138 trie_node* cur = t->root;
139 for (
const uint8_t* p = (
const uint8_t*)word; *p; p++) {
140 cur = child_find_or_create(cur, *p);
141 if (!cur)
return false;
144 if (!cur->is_end_of_word) {
145 cur->is_end_of_word =
true;
153 if (!t || !word || !*word)
return false;
155 const trie_node* cur = t->root;
156 for (
const uint8_t* p = (
const uint8_t*)word; *p; p++) {
157 int idx = child_index(cur, *p);
158 if (idx < 0)
return false;
159 cur = cur->children[idx];
161 return cur->is_end_of_word;
165 if (!t || !prefix || !*prefix)
return false;
167 const trie_node* cur = t->root;
168 for (
const uint8_t* p = (
const uint8_t*)prefix; *p; p++) {
169 int idx = child_index(cur, *p);
170 if (idx < 0)
return false;
171 cur = cur->children[idx];
177 if (!t || !word || !*word)
return false;
179 const trie_node* cur = t->root;
180 for (
const uint8_t* p = (
const uint8_t*)word; *p; p++) {
181 int idx = child_index(cur, *p);
182 if (idx < 0)
return false;
183 cur = cur->children[idx];
186 if (!cur->is_end_of_word)
return false;
187 ((trie_node*)cur)->is_end_of_word =
false;
188 ((trie_node*)cur)->frequency = 0;
194 if (!t || !word || !*word)
return 0;
196 const trie_node* cur = t->root;
197 for (
const uint8_t* p = (
const uint8_t*)word; *p; p++) {
198 int idx = child_index(cur, *p);
199 if (idx < 0)
return 0;
200 cur = cur->children[idx];
202 return cur->is_end_of_word ? cur->frequency : 0;
219static void _collect(
const trie_node* node,
char* buf,
size_t depth,
size_t buf_max, _collector* c, Arena* arena) {
220 if (!node || c->count >= c->limit)
return;
222 if (node->is_end_of_word) {
224 char* dup = arena_strdup(arena, buf);
225 if (dup) c->suggestions[c->count++] = dup;
228 for (uint8_t i = 0; i < node->nchildren && c->count < c->limit; i++) {
229 if (depth + 1 >= buf_max)
return;
230 buf[depth] = (char)node->chars[i];
231 _collect(node->children[i], buf, depth + 1, buf_max, c, arena);
237 if (out_count) *out_count = 0;
238 if (!t || !prefix || !out_count || max_suggestions == 0 || !arena)
return NULL;
241 const trie_node* cur = t->root;
242 for (
const uint8_t* p = (
const uint8_t*)prefix; *p; p++) {
243 int idx = child_index(cur, *p);
244 if (idx < 0)
return NULL;
245 cur = cur->children[idx];
249 char** suggestions = ARENA_ALLOC_ARRAY(arena,
char*, max_suggestions);
250 if (!suggestions)
return NULL;
252 const size_t buf_max = 1024;
253 char* buf = (
char*)arena_alloc(arena, buf_max);
254 if (!buf)
return NULL;
256 size_t prefix_len = strlen(prefix);
257 if (prefix_len >= buf_max)
return NULL;
258 memcpy(buf, prefix, prefix_len);
260 _collector c = {.suggestions = suggestions, .count = 0, .capacity = max_suggestions, .limit = max_suggestions};
262 _collect(cur, buf, prefix_len, buf_max, &c, arena);
264 if (c.count == 0)
return NULL;
266 *out_count = c.count;
267 return (
const char**)suggestions;
void trie_destroy(trie_t *trie)
bool trie_starts_with(const trie_t *trie, const char *prefix)
bool trie_search(const trie_t *trie, const char *word)
bool trie_is_empty(const trie_t *trie)
bool trie_insert(trie_t *trie, const char *word)
uint32_t trie_get_frequency(const trie_t *trie, const char *word)
size_t trie_get_word_count(const trie_t *trie)
trie_t * trie_create(void)
const char ** trie_autocomplete(const trie_t *trie, const char *prefix, size_t max_suggestions, size_t *out_count, Arena *arena)
bool trie_delete(trie_t *trie, const char *word)