kv_trie_rs/trie/
trie.rs

1use super::Trie;
2use louds_rs::LoudsNodeNum;
3
4impl<K: Ord + Clone, V: Clone> Trie<K, V> {
5    pub fn exact_match<Key: AsRef<[K]>>(&self, query: Key) -> bool {
6        let mut cur_node_num = LoudsNodeNum(1);
7
8        for (i, chr) in query.as_ref().iter().enumerate() {
9            let children_node_nums = self.children_node_nums(cur_node_num);
10            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
11
12            match res {
13                Ok(j) => {
14                    let child_node_num = children_node_nums[j];
15                    if i == query.as_ref().len() - 1 && self.is_terminal(child_node_num) {
16                        return true;
17                    };
18                    cur_node_num = child_node_num;
19                }
20                Err(_) => return false,
21            }
22        }
23        false
24    }
25
26    pub fn get<Key: AsRef<[K]>>(&self, query: Key) -> Option<&V> {
27        let mut cur_node_num = LoudsNodeNum(1);
28
29        for (i, chr) in query.as_ref().iter().enumerate() {
30            let children_node_nums = self.children_node_nums(cur_node_num);
31            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
32
33            match res {
34                Ok(j) => {
35                    let child_node_num = children_node_nums[j];
36                    if i == query.as_ref().len() - 1 && self.is_terminal(child_node_num) {
37                        let value_opts = &self.trie_labels[child_node_num.0 as usize - 2].value;
38                        return match value_opts {
39                            Some(value) => Some(value),
40                            None => None,
41                        };
42                    };
43                    cur_node_num = child_node_num;
44                }
45                Err(_) => return None,
46            }
47        }
48        None
49    }
50
51    pub fn get_mut<Key: AsRef<[K]>>(&mut self, query: Key) -> Option<&mut V> {
52        let mut cur_node_num = LoudsNodeNum(1);
53
54        for (i, chr) in query.as_ref().iter().enumerate() {
55            let children_node_nums = self.children_node_nums(cur_node_num);
56            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
57
58            match res {
59                Ok(j) => {
60                    let child_node_num = children_node_nums[j];
61                    if i == query.as_ref().len() - 1 && self.is_terminal(child_node_num) {
62                        let value_opts = &mut self.trie_labels[child_node_num.0 as usize - 2].value;
63                        return match value_opts {
64                            Some(ref mut value) => Some(value),
65                            None => None,
66                        };
67                    };
68                    cur_node_num = child_node_num;
69                }
70                Err(_) => {
71                    return None;
72                }
73            }
74        }
75        None
76    }
77
78    pub fn set<Key: AsRef<[K]>>(&mut self, query: Key, value: V) {
79        let mut cur_node_num = LoudsNodeNum(1);
80
81        for (i, chr) in query.as_ref().iter().enumerate() {
82            let children_node_nums = self.children_node_nums(cur_node_num);
83            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
84
85            if let Ok(j) = res {
86                let child_node_num = children_node_nums[j];
87                if i == query.as_ref().len() - 1 {
88                    self.trie_labels[child_node_num.0 as usize - 2].value = Some(value);
89                    return;
90                };
91                cur_node_num = child_node_num;
92            }
93        }
94    }
95
96    /// # Panics
97    /// If `query` is empty.
98    pub fn predictive_search<Arr: AsRef<[K]>>(&self, query: Arr) -> Vec<Vec<K>> {
99        self.rec_predictive_search(query, LoudsNodeNum(1))
100    }
101    fn rec_predictive_search<Arr: AsRef<[K]>>(
102        &self,
103        query: Arr,
104        node_num: LoudsNodeNum,
105    ) -> Vec<Vec<K>> {
106        assert!(!query.as_ref().is_empty());
107        let mut cur_node_num = node_num;
108
109        // Consumes query (prefix)
110        for chr in query.as_ref() {
111            let children_node_nums = self.children_node_nums(cur_node_num);
112            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
113            match res {
114                Ok(i) => cur_node_num = children_node_nums[i],
115                Err(_) => return vec![],
116            }
117        }
118
119        let mut results = if self.is_terminal(cur_node_num) {
120            vec![query.as_ref().to_vec()]
121        } else {
122            vec![]
123        };
124        let all_words_under_cur: Vec<Vec<K>> = self
125            .children_node_nums(cur_node_num)
126            .iter()
127            .flat_map(|child_node_num| {
128                self.rec_predictive_search(vec![self.label(*child_node_num)], cur_node_num)
129            })
130            .collect();
131
132        for word in all_words_under_cur {
133            let mut result: Vec<K> = query.as_ref().to_vec();
134            result.extend(word);
135            results.push(result);
136        }
137        results
138    }
139
140    pub fn common_prefix_search<Key: AsRef<[K]>>(&self, query: Key) -> Vec<Vec<K>> {
141        let mut results: Vec<Vec<K>> = Vec::new();
142        let mut labels_in_path: Vec<K> = Vec::new();
143
144        let mut cur_node_num = LoudsNodeNum(1);
145
146        for chr in query.as_ref() {
147            let children_node_nums = self.children_node_nums(cur_node_num);
148            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
149            match res {
150                Ok(j) => {
151                    let child_node_num = children_node_nums[j];
152                    labels_in_path.push(self.label(child_node_num));
153                    if self.is_terminal(child_node_num) {
154                        results.push(labels_in_path.clone());
155                    };
156                    cur_node_num = child_node_num;
157                }
158                Err(_) => break,
159            }
160        }
161        results
162    }
163
164    pub fn common_prefix_search_with_values<Key: AsRef<[K]>>(
165        &self,
166        query: Key,
167    ) -> Vec<(Vec<K>, V)> {
168        let mut results: Vec<(Vec<K>, V)> = Vec::new();
169        let mut labels_in_path: Vec<K> = Vec::new();
170
171        let mut cur_node_num = LoudsNodeNum(1);
172
173        for chr in query.as_ref() {
174            let children_node_nums = self.children_node_nums(cur_node_num);
175            let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
176            match res {
177                Ok(j) => {
178                    let child_node_num = children_node_nums[j];
179                    labels_in_path.push(self.label(child_node_num));
180                    if self.is_terminal(child_node_num) {
181                        match self.value(child_node_num) {
182                            Some(value) => results.push((labels_in_path.clone(), value)),
183                            None => panic!("Trie is inconsistent"),
184                        }
185                    };
186                    cur_node_num = child_node_num;
187                }
188                Err(_) => break,
189            }
190        }
191        results
192    }
193
194    fn children_node_nums(&self, node_num: LoudsNodeNum) -> Vec<LoudsNodeNum> {
195        self.louds
196            .parent_to_children(node_num)
197            .iter()
198            .map(|child_idx| self.louds.index_to_node_num(*child_idx))
199            .collect()
200    }
201
202    fn bin_search_by_children_labels(
203        &self,
204        query: &K,
205        children_node_nums: &[LoudsNodeNum],
206    ) -> Result<usize, usize> {
207        children_node_nums.binary_search_by_key(query, |child_node_num| self.label(*child_node_num))
208    }
209
210    fn label(&self, node_num: LoudsNodeNum) -> K {
211        self.trie_labels[(node_num.0 - 2) as usize].key.clone()
212    }
213
214    fn value(&self, node_num: LoudsNodeNum) -> Option<V> {
215        self.trie_labels[(node_num.0 - 2) as usize].value.clone()
216    }
217
218    fn is_terminal(&self, node_num: LoudsNodeNum) -> bool {
219        self.trie_labels[(node_num.0 - 2) as usize].is_terminal
220    }
221}
222
#[cfg(test)]
mod search_tests {
    use crate::{Trie, TrieBuilder};

    // Shared fixture: six distinct keys (one non-ASCII), each with a unique value.
    fn build_trie() -> Trie<u8, String> {
        let entries = [
            ("a", "random_value_1"),
            ("app", "random_value_2"),
            ("apple", "random_value_3"),
            ("better", "random_value_4"),
            ("application", "random_value_5"),
            ("アップル🍎", "random_value_6"),
        ];
        let mut builder = TrieBuilder::new();
        for &(key, value) in entries.iter() {
            builder.push(key, value.to_string());
        }
        builder.build()
    }

    // Fixture in which "a" and "app" are each pushed twice with different values.
    fn build_trie_mut() -> Trie<u8, String> {
        let entries = [
            ("a", "random_value_1"),
            ("a", ""),
            ("app", "random_value_2"),
            ("app", "random_value_3"),
            ("apple", "random_value_4"),
        ];
        let mut builder = TrieBuilder::new();
        for &(key, value) in entries.iter() {
            builder.push(key, value.to_string());
        }
        builder.build()
    }

    #[test]
    fn test_common_prefix_with_values_search() {
        let expected = vec![
            (b"a".to_vec(), "random_value_1".to_string()),
            (b"app".to_vec(), "random_value_2".to_string()),
            (b"apple".to_vec(), "random_value_3".to_string()),
        ];
        assert_eq!(build_trie().common_prefix_search_with_values("apple"), expected);
    }

    #[test]
    fn test_get_mut() {
        let mut trie = build_trie_mut();
        // The value pushed last for a repeated key is the one expected here.
        assert_eq!(trie.get_mut("a").map(|v| v.as_str()), Some(""));
        assert_eq!(trie.get_mut("apple").map(|v| v.as_str()), Some("random_value_4"));
        assert_eq!(trie.get_mut("app").map(|v| v.as_str()), Some("random_value_3"));
    }

    #[test]
    fn test_get() {
        assert_eq!(build_trie().get("better"), Some(&"random_value_4".to_string()));
    }

    #[test]
    fn test_set_multiple() {
        let initial = vec!["1", "2", "3", "4"];
        let mut builder = TrieBuilder::new();
        builder.push("a", vec!["x", "y"]);
        builder.push("axe", initial.clone());
        let mut trie = builder.build();

        let mut updated = initial;
        updated.push("5");
        trie.set("axe", updated);

        assert_eq!(trie.get("a").unwrap(), &vec!["x", "y"]);
        assert_eq!(trie.get("ax"), None);
        assert_eq!(trie.get("axe").unwrap(), &vec!["1", "2", "3", "4", "5"]);
    }

    // One #[test] per case: calls `$method` on the shared fixture and compares
    // against the expected keys, written as string literals.
    macro_rules! key_list_cases {
        ($method:ident; $($name:ident: $case:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (query, expected) = $case;
                    let expected: Vec<Vec<u8>> =
                        expected.iter().map(|s| s.as_bytes().to_vec()).collect();
                    assert_eq!(super::build_trie().$method(query), expected);
                }
            )*
        };
    }

    mod exact_match_tests {
        macro_rules! exact_match_cases {
            ($($name:ident: $case:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected) = $case;
                        assert_eq!(super::build_trie().exact_match(query), expected);
                    }
                )*
            };
        }

        exact_match_cases! {
            t1: ("a", true),
            t2: ("app", true),
            t3: ("apple", true),
            t4: ("application", true),
            t5: ("better", true),
            t6: ("アップル🍎", true),
            t7: ("appl", false),
            t8: ("appler", false),
        }
    }

    mod predictive_search_tests {
        key_list_cases! { predictive_search;
            t1: ("a", vec!["a", "app", "apple", "application"]),
            t2: ("app", vec!["app", "apple", "application"]),
            t3: ("appl", vec!["apple", "application"]),
            t4: ("apple", vec!["apple"]),
            t5: ("b", vec!["better"]),
            t6: ("c", Vec::<&str>::new()),
            t7: ("アップ", vec!["アップル🍎"]),
        }
    }

    mod common_prefix_search_tests {
        key_list_cases! { common_prefix_search;
            t1: ("a", vec!["a"]),
            t2: ("ap", vec!["a"]),
            t3: ("appl", vec!["a", "app"]),
            t4: ("appler", vec!["a", "app", "apple"]),
            t5: ("bette", Vec::<&str>::new()),
            t6: ("betterment", vec!["better"]),
            t7: ("c", Vec::<&str>::new()),
            t8: ("アップル🍎🍏", vec!["アップル🍎"]),
        }
    }
}