1use super::Trie;
2use louds_rs::LoudsNodeNum;
3
4impl<K: Ord + Clone, V: Clone> Trie<K, V> {
5 pub fn exact_match<Key: AsRef<[K]>>(&self, query: Key) -> bool {
6 let mut cur_node_num = LoudsNodeNum(1);
7
8 for (i, chr) in query.as_ref().iter().enumerate() {
9 let children_node_nums = self.children_node_nums(cur_node_num);
10 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
11
12 match res {
13 Ok(j) => {
14 let child_node_num = children_node_nums[j];
15 if i == query.as_ref().len() - 1 && self.is_terminal(child_node_num) {
16 return true;
17 };
18 cur_node_num = child_node_num;
19 }
20 Err(_) => return false,
21 }
22 }
23 false
24 }
25
26 pub fn get<Key: AsRef<[K]>>(&self, query: Key) -> Option<&V> {
27 let mut cur_node_num = LoudsNodeNum(1);
28
29 for (i, chr) in query.as_ref().iter().enumerate() {
30 let children_node_nums = self.children_node_nums(cur_node_num);
31 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
32
33 match res {
34 Ok(j) => {
35 let child_node_num = children_node_nums[j];
36 if i == query.as_ref().len() - 1 && self.is_terminal(child_node_num) {
37 let value_opts = &self.trie_labels[child_node_num.0 as usize - 2].value;
38 return match value_opts {
39 Some(value) => Some(value),
40 None => None,
41 };
42 };
43 cur_node_num = child_node_num;
44 }
45 Err(_) => return None,
46 }
47 }
48 None
49 }
50
51 pub fn get_mut<Key: AsRef<[K]>>(&mut self, query: Key) -> Option<&mut V> {
52 let mut cur_node_num = LoudsNodeNum(1);
53
54 for (i, chr) in query.as_ref().iter().enumerate() {
55 let children_node_nums = self.children_node_nums(cur_node_num);
56 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
57
58 match res {
59 Ok(j) => {
60 let child_node_num = children_node_nums[j];
61 if i == query.as_ref().len() - 1 && self.is_terminal(child_node_num) {
62 let value_opts = &mut self.trie_labels[child_node_num.0 as usize - 2].value;
63 return match value_opts {
64 Some(ref mut value) => Some(value),
65 None => None,
66 };
67 };
68 cur_node_num = child_node_num;
69 }
70 Err(_) => {
71 return None;
72 }
73 }
74 }
75 None
76 }
77
78 pub fn set<Key: AsRef<[K]>>(&mut self, query: Key, value: V) {
79 let mut cur_node_num = LoudsNodeNum(1);
80
81 for (i, chr) in query.as_ref().iter().enumerate() {
82 let children_node_nums = self.children_node_nums(cur_node_num);
83 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
84
85 if let Ok(j) = res {
86 let child_node_num = children_node_nums[j];
87 if i == query.as_ref().len() - 1 {
88 self.trie_labels[child_node_num.0 as usize - 2].value = Some(value);
89 return;
90 };
91 cur_node_num = child_node_num;
92 }
93 }
94 }
95
96 pub fn predictive_search<Arr: AsRef<[K]>>(&self, query: Arr) -> Vec<Vec<K>> {
99 self.rec_predictive_search(query, LoudsNodeNum(1))
100 }
101 fn rec_predictive_search<Arr: AsRef<[K]>>(
102 &self,
103 query: Arr,
104 node_num: LoudsNodeNum,
105 ) -> Vec<Vec<K>> {
106 assert!(!query.as_ref().is_empty());
107 let mut cur_node_num = node_num;
108
109 for chr in query.as_ref() {
111 let children_node_nums = self.children_node_nums(cur_node_num);
112 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
113 match res {
114 Ok(i) => cur_node_num = children_node_nums[i],
115 Err(_) => return vec![],
116 }
117 }
118
119 let mut results = if self.is_terminal(cur_node_num) {
120 vec![query.as_ref().to_vec()]
121 } else {
122 vec![]
123 };
124 let all_words_under_cur: Vec<Vec<K>> = self
125 .children_node_nums(cur_node_num)
126 .iter()
127 .flat_map(|child_node_num| {
128 self.rec_predictive_search(vec![self.label(*child_node_num)], cur_node_num)
129 })
130 .collect();
131
132 for word in all_words_under_cur {
133 let mut result: Vec<K> = query.as_ref().to_vec();
134 result.extend(word);
135 results.push(result);
136 }
137 results
138 }
139
140 pub fn common_prefix_search<Key: AsRef<[K]>>(&self, query: Key) -> Vec<Vec<K>> {
141 let mut results: Vec<Vec<K>> = Vec::new();
142 let mut labels_in_path: Vec<K> = Vec::new();
143
144 let mut cur_node_num = LoudsNodeNum(1);
145
146 for chr in query.as_ref() {
147 let children_node_nums = self.children_node_nums(cur_node_num);
148 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
149 match res {
150 Ok(j) => {
151 let child_node_num = children_node_nums[j];
152 labels_in_path.push(self.label(child_node_num));
153 if self.is_terminal(child_node_num) {
154 results.push(labels_in_path.clone());
155 };
156 cur_node_num = child_node_num;
157 }
158 Err(_) => break,
159 }
160 }
161 results
162 }
163
164 pub fn common_prefix_search_with_values<Key: AsRef<[K]>>(
165 &self,
166 query: Key,
167 ) -> Vec<(Vec<K>, V)> {
168 let mut results: Vec<(Vec<K>, V)> = Vec::new();
169 let mut labels_in_path: Vec<K> = Vec::new();
170
171 let mut cur_node_num = LoudsNodeNum(1);
172
173 for chr in query.as_ref() {
174 let children_node_nums = self.children_node_nums(cur_node_num);
175 let res = self.bin_search_by_children_labels(chr, &children_node_nums[..]);
176 match res {
177 Ok(j) => {
178 let child_node_num = children_node_nums[j];
179 labels_in_path.push(self.label(child_node_num));
180 if self.is_terminal(child_node_num) {
181 match self.value(child_node_num) {
182 Some(value) => results.push((labels_in_path.clone(), value)),
183 None => panic!("Trie is inconsistent"),
184 }
185 };
186 cur_node_num = child_node_num;
187 }
188 Err(_) => break,
189 }
190 }
191 results
192 }
193
194 fn children_node_nums(&self, node_num: LoudsNodeNum) -> Vec<LoudsNodeNum> {
195 self.louds
196 .parent_to_children(node_num)
197 .iter()
198 .map(|child_idx| self.louds.index_to_node_num(*child_idx))
199 .collect()
200 }
201
202 fn bin_search_by_children_labels(
203 &self,
204 query: &K,
205 children_node_nums: &[LoudsNodeNum],
206 ) -> Result<usize, usize> {
207 children_node_nums.binary_search_by_key(query, |child_node_num| self.label(*child_node_num))
208 }
209
210 fn label(&self, node_num: LoudsNodeNum) -> K {
211 self.trie_labels[(node_num.0 - 2) as usize].key.clone()
212 }
213
214 fn value(&self, node_num: LoudsNodeNum) -> Option<V> {
215 self.trie_labels[(node_num.0 - 2) as usize].value.clone()
216 }
217
218 fn is_terminal(&self, node_num: LoudsNodeNum) -> bool {
219 self.trie_labels[(node_num.0 - 2) as usize].is_terminal
220 }
221}
222
#[cfg(test)]
mod search_tests {
    use crate::{Trie, TrieBuilder};

    /// Trie with nested keys ("a" ⊂ "app" ⊂ "apple"/"application"), an unrelated key and
    /// a multi-byte UTF-8 key.
    fn build_trie() -> Trie<u8, String> {
        let mut builder = TrieBuilder::new();
        builder.push("a", "random_value_1".to_string());
        builder.push("app", "random_value_2".to_string());
        builder.push("apple", "random_value_3".to_string());
        builder.push("better", "random_value_4".to_string());
        builder.push("application", "random_value_5".to_string());
        builder.push("アップル🍎", "random_value_6".to_string());
        builder.build()
    }

    /// Trie where "a" and "app" are pushed twice with different values.
    fn build_trie_mut() -> Trie<u8, String> {
        let mut builder = TrieBuilder::new();
        builder.push("a", "random_value_1".to_string());
        builder.push("a", "".to_string());
        builder.push("app", "random_value_2".to_string());
        builder.push("app", "random_value_3".to_string());
        builder.push("apple", "random_value_4".to_string());
        builder.build()
    }

    #[test]
    fn test_common_prefix_with_values_search() {
        let trie = build_trie();
        let result = trie.common_prefix_search_with_values("apple");
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].0, b"a".to_vec());
        assert_eq!(result[0].1, "random_value_1".to_string());
        assert_eq!(result[1].0, b"app".to_vec());
        assert_eq!(result[1].1, "random_value_2".to_string());
        assert_eq!(result[2].0, b"apple".to_vec());
        assert_eq!(result[2].1, "random_value_3".to_string());
    }

    #[test]
    fn test_get_mut() {
        let mut trie = build_trie_mut();
        // The later push of a duplicated key is expected to win.
        assert_eq!(trie.get_mut("a").cloned(), Some("".to_string()));
        assert_eq!(
            trie.get_mut("apple").cloned(),
            Some("random_value_4".to_string())
        );
        assert_eq!(
            trie.get_mut("app").cloned(),
            Some("random_value_3".to_string())
        );

        // Writes through the returned reference must be visible to later lookups.
        trie.get_mut("apple").unwrap().push_str("_edited");
        assert_eq!(
            trie.get("apple").map(String::as_str),
            Some("random_value_4_edited")
        );

        // Prefixes of keys and unknown paths have no value to borrow.
        assert_eq!(trie.get_mut("appl"), None);
        assert_eq!(trie.get_mut("b"), None);
    }

    #[test]
    fn test_get() {
        let trie = build_trie();
        let result = trie.get("better");
        assert_eq!(result.unwrap(), &"random_value_4".to_string());
        assert_eq!(trie.get("bett"), None);
        assert_eq!(trie.get(""), None);
    }

    #[test]
    fn test_set_multiple() {
        let mut builder = TrieBuilder::new();
        let mut contents = vec!["1", "2", "3", "4"];
        builder.push("a", vec!["x", "y"]);
        builder.push("axe", contents.clone());
        contents.push("5");
        let mut trie = builder.build();
        trie.set("axe", contents);
        assert_eq!(trie.get("a").unwrap(), &vec!["x", "y"]);
        assert_eq!(trie.get("ax"), None);
        assert_eq!(trie.get("axe").unwrap(), &vec!["1", "2", "3", "4", "5"]);
    }

    #[test]
    fn test_set_ignores_non_keys() {
        let mut trie = build_trie();
        // "appl" is only a prefix of stored keys; "zzz" is not in the trie at all.
        trie.set("appl", "ignored".to_string());
        trie.set("zzz", "ignored".to_string());
        assert_eq!(trie.get("appl"), None);
        assert_eq!(trie.get("zzz"), None);
        assert_eq!(
            trie.get("apple").map(String::as_str),
            Some("random_value_3")
        );
    }

    mod exact_match_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_match) = $value;
                        let trie = super::build_trie();
                        let result = trie.exact_match(query);
                        assert_eq!(result, expected_match);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", true),
            t2: ("app", true),
            t3: ("apple", true),
            t4: ("application", true),
            t5: ("better", true),
            t6: ("アップル🍎", true),
            t7: ("appl", false),
            t8: ("appler", false),
            t9: ("", false),
        }
    }

    mod predictive_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_results) = $value;
                        let trie = super::build_trie();
                        let results = trie.predictive_search(query);
                        let expected_results: Vec<Vec<u8>> = expected_results.iter().map(|s| s.as_bytes().to_vec()).collect();
                        assert_eq!(results, expected_results);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec!["a", "app", "apple", "application"]),
            t2: ("app", vec!["app", "apple", "application"]),
            t3: ("appl", vec!["apple", "application"]),
            t4: ("apple", vec!["apple"]),
            t5: ("b", vec!["better"]),
            t6: ("c", Vec::<&str>::new()),
            t7: ("アップ", vec!["アップル🍎"]),
        }
    }

    mod common_prefix_search_tests {
        macro_rules! parameterized_tests {
            ($($name:ident: $value:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (query, expected_results) = $value;
                        let trie = super::build_trie();
                        let results = trie.common_prefix_search(query);
                        let expected_results: Vec<Vec<u8>> = expected_results.iter().map(|s| s.as_bytes().to_vec()).collect();
                        assert_eq!(results, expected_results);
                    }
                )*
            }
        }

        parameterized_tests! {
            t1: ("a", vec!["a"]),
            t2: ("ap", vec!["a"]),
            t3: ("appl", vec!["a", "app"]),
            t4: ("appler", vec!["a", "app", "apple"]),
            t5: ("bette", Vec::<&str>::new()),
            t6: ("betterment", vec!["better"]),
            t7: ("c", Vec::<&str>::new()),
            t8: ("アップル🍎🍏", vec!["アップル🍎"]),
            t9: ("", Vec::<&str>::new()),
        }
    }
}