forked from UnixJunkie/vp-tree
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvp_tree.ml
309 lines (273 loc) · 10.2 KB
/
vp_tree.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
module A = MyArray
(* Vantage-point tree implementation.
Cf. "Data structures and algorithms for nearest neighbor search
in general metric spaces" by Peter N. Yianilos for details:
http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.41.4193&rep=rep1&type=pdf *)
(* Functorial interface *)
(* What a vantage-point tree needs to know about the points it stores. *)
module type Point = sig
  (* the type of points *)
  type t

  (* [dist a b] is the distance between [a] and [b].
     dist _must_ be a metric (non-negative, zero only between identical
     points, symmetric, obeying the triangle inequality): tree searches
     prune whole subtrees on the strength of the triangle inequality. *)
  val dist : t -> t -> float
end
module Make = functor (P: Point) ->
struct

  (* How hard [create] works to pick each vantage point: better vantage
     points mean slower tree construction but faster queries. *)
  type quality = Optimal (* if you have thousands of points *)
               | Good of int (* if you have tens to hundreds
                                of thousands of points; the int is the
                                sample size *)
               | Random (* if you have millions of points *)

  (* A node stores a vantage point [vp]; the points of [left] are at
     distance < [middle] from [vp], those of [right] at distance >= [middle]
     (the invariant [check] verifies). [lb_low]/[lb_high] (resp.
     [rb_low]/[rb_high]) are the smallest/largest distances from [vp] to the
     points of [left] (resp. [right]). All five floats are 0. in a
     single-point leaf. *)
  type node = { vp: P.t;
                lb_low: float;
                lb_high: float;
                middle: float;
                rb_low: float;
                rb_high: float;
                left: t;
                right: t }
  and t = Empty
        | Node of node

  let new_node vp lb_low lb_high middle rb_low rb_high left right =
    Node { vp; lb_low; lb_high; middle; rb_low; rb_high; left; right }

  (* An interval given by its two bounds ([lbound <= rbound] is asserted on
     creation). [in_open_itv] treats it as open; [itv_overlap] and
     [itv_dont_overlap] treat both of their arguments as closed. *)
  type open_itv = { lbound: float; rbound: float }

  let new_open_itv lbound rbound =
    assert (lbound <= rbound);
    { lbound; rbound }

  (* lbound < x < rbound *)
  let in_open_itv x { lbound; rbound } =
    (x > lbound) && (x < rbound)

  (* true when the closed intervals [left] and [right] are disjoint *)
  let itv_dont_overlap left right =
    let a = left.lbound in
    let b = left.rbound in
    let c = right.lbound in
    let d = right.rbound in
    (* [a..b] [c..d] OR [c..d] [a..b] *)
    (b < c) || (d < a)

  let itv_overlap left right =
    not (itv_dont_overlap left right)

  let square (x: float): float =
    x *. x

  (* Three-way comparison of floats; a NaN compares equal to anything. *)
  let float_compare (x: float) (y: float): int =
    if x < y then -1
    else if x > y then 1
    else 0 (* x = y *)

  (* Median of the non-empty array [xs] (mean of the two middle values when
     the length is even). Sorts [xs] in place. *)
  let median (xs: float array): float =
    A.sort float_compare xs;
    let n = A.length xs in
    if n mod 2 = 1 then xs.(n / 2)
    else 0.5 *. (xs.(n / 2) +. xs.(n / 2 - 1))

  (* Sum of squared deviations of [xs] from [mu]. Not divided by the length:
     it is only used to rank candidate vantage points against each other. *)
  let variance (mu: float) (xs: float array): float =
    A.fold_left (fun acc x -> acc +. square (x -. mu)) 0.0 xs

  (* Distances from the point at index [q_i] to all the other points, in
     index order (length n - 1). Requires at least two points. *)
  let distances (q_i: int) (points: P.t array): float array =
    let n = A.length points in
    assert (n > 1);
    let res = A.make (n - 1) 0.0 in
    let j = ref 0 in
    let q = points.(q_i) in
    for i = 0 to n - 1 do
      if i <> q_i then
        (res.(!j) <- P.dist q points.(i);
         incr j)
    done;
    res

  (* Optimal vantage point: try every point and keep the one whose distances
     to all the others have the largest spread around their median.
     Slowest tree construction (O(n^2) distance computations per call) but
     fastest query time.
     Returns (vp, mu, others): mu is the median distance from vp to the other
     points, others is [A.remove] of vp's index from [points]. *)
  let select_best_vp (points: P.t array) =
    let n = A.length points in
    if n = 0 then assert false
    else if n = 1 then (points.(0), 0.0, [||])
    else
      let curr_vp = ref 0 in
      let curr_mu = ref 0.0 in
      (* neg_infinity: the first candidate is always accepted, so curr_mu is
         the median for the returned vp even when every spread is 0 *)
      let curr_spread = ref neg_infinity in
      for i = 0 to n - 1 do
        (* could be faster using a distance cache? *)
        let dists = distances i points in
        let mu = median dists in
        let spread = variance mu dists in
        if spread > !curr_spread then
          (curr_vp := i;
           curr_mu := mu;
           curr_spread := spread)
      done;
      (points.(!curr_vp), !curr_mu, A.remove !curr_vp points)

  (* To replace select_best_vp when working with too many points: score
     [sample_size] randomly drawn candidates, each against a bootstrap sample
     of [sample_size] points. Falls back to [select_best_vp] when
     sample_size^2 >= number of points. Same result shape as
     [select_best_vp]. *)
  let select_good_vp rng (sample_size: int) (points: P.t array) =
    let n = A.length points in
    if sample_size * sample_size >= n then
      select_best_vp points
    else
      let curr_vp = ref 0 in
      let curr_spread = ref neg_infinity in
      for _k = 1 to sample_size do
        (* draw the candidate as an index into [points] (with replacement):
           that index is what we return and remove below *)
        let i = Random.State.int rng n in
        let sample = A.bootstrap_sample rng sample_size points in
        let dists = A.map (P.dist points.(i)) sample in
        let mu = median dists in
        let spread = variance mu dists in
        if spread > !curr_spread then
          (curr_vp := i;
           curr_spread := spread)
      done;
      (* we need the true mu to balance the tree;
         not the one gotten from the sample! *)
      let dists = distances !curr_vp points in
      let mu = median dists in
      (points.(!curr_vp), mu, A.remove !curr_vp points)

  (* To replace select_good_vp when working with way too many points,
     or if you really need the fastest possible tree construction:
     the vantage point is drawn uniformly at random. *)
  let select_rand_vp rng (points: P.t array) =
    let n = A.length points in
    assert (n > 0);
    let vp = Random.State.int rng n in
    let dists = distances vp points in
    let mu = median dists in
    (points.(vp), mu, A.remove vp points)

  (* not raised by anything in this module *)
  exception Empty_list

  (* Build a tree from [points], choosing each node's vantage point with
     [select_vp]; points at distance < mu from it go left, the rest right. *)
  let rec create' select_vp points =
    let n = A.length points in
    if n = 0 then Empty
    else if n = 1 then new_node points.(0) 0. 0. 0. 0. 0. Empty Empty
    else
      let vp, mu, others = select_vp points in
      let dists = A.map (fun p -> (P.dist vp p, p)) others in
      let lefties, righties = A.partition (fun (d, _p) -> d < mu) dists in
      let ldists, lpoints = A.split lefties in
      let rdists, rpoints = A.split righties in
      (* (0., 0.) is presumably what min_max_def yields for an empty side *)
      let lb_low, lb_high = A.min_max_def ldists (0., 0.) in
      let rb_low, rb_high = A.min_max_def rdists (0., 0.) in
      let middle = (lb_high +. rb_low) *. 0.5 in
      new_node vp lb_low lb_high middle rb_low rb_high
        (create' select_vp lpoints) (create' select_vp rpoints)

  (* Build a tree from the list [points]; [quality] selects the
     vantage-point strategy. [Good] and [Random] use a self-initialised RNG,
     so the resulting tree shape is not reproducible. *)
  let create quality points =
    let select_vp = match quality with
      | Optimal -> select_best_vp
      | Good ssize -> select_good_vp (Random.State.make_self_init ()) ssize
      | Random -> select_rand_vp (Random.State.make_self_init ()) in
    create' select_vp (A.of_list points)

  (* [find_nearest acc query tree]: [acc] is the best (distance, point)
     found so far, if any. Returns [acc] when [tree] is Empty; otherwise a
     pair whose distance is the minimum of [acc]'s and that of the point of
     [tree] nearest to [query]. *)
  let rec find_nearest acc query tree =
    match tree with
    | Empty -> acc
    | Node { vp; lb_low; lb_high; middle; rb_low; rb_high; left; right } ->
      let x = P.dist vp query in
      if x = 0.0 then Some (x, vp) (* can't get nearer than that *)
      else
        let tau, acc' =
          match acc with
          | None -> (x, Some (x, vp))
          | Some (tau, _) ->
            if x < tau then (x, Some (x, vp))
            else (tau, acc) in
        (* triangle inequality: a subtree whose distances to vp lie in
           [low, high] can only hold a point closer than tau to the query
           if low - tau < x < high + tau *)
        let il = new_open_itv (lb_low -. tau) (lb_high +. tau) in
        let ir = new_open_itv (rb_low -. tau) (rb_high +. tau) in
        let in_il = in_open_itv x il in
        let in_ir = in_open_itv x ir in
        (* search the side the query falls on first; its result (possibly
           with a smaller tau) then prunes the search of the other side *)
        let (near, in_near), (far, in_far) =
          if x < middle then ((left, in_il), (right, in_ir))
          else ((right, in_ir), (left, in_il)) in
        let acc'' = if in_near then find_nearest acc' query near else acc' in
        if in_far then find_nearest acc'' query far else acc''

  (* (distance, point) of a point of [tree] nearest to [query].
     @raise Not_found if [tree] is empty. *)
  let nearest_neighbor query tree =
    match find_nearest None query tree with
    | Some x -> x
    | None -> raise Not_found

  (* Prepend the points of a tree to [acc], in order: left subtree, vp,
     right subtree. *)
  let rec to_list_loop acc = function
    | Empty -> acc
    | Node n ->
      let acc' = to_list_loop acc n.right in
      to_list_loop (n.vp :: acc') n.left

  let to_list tree =
    to_list_loop [] tree

  (* All points of [tree] at distance <= [tol] from [query]. *)
  let neighbors query tol tree =
    let rec loop acc = function
      | Empty -> acc
      | Node { vp; lb_low; lb_high; rb_low; rb_high; left; right; _ } ->
        let d = P.dist vp query in
        let acc' = if d <= tol then vp :: acc else acc in
        (* a point p can only match if |dist vp p - d| <= tol *)
        let itv = new_open_itv (max 0.0 (d -. tol)) (d +. tol) in
        (* add to [acc] the matches from subtree [sub], whose distances to
           vp lie in [low, high] *)
        let visit acc low high sub =
          if itv_dont_overlap itv (new_open_itv low high) then acc
          else if d +. high <= tol then
            (* all descendants are included: no further P.dist calls *)
            to_list_loop acc sub
          else loop acc sub in
        let lmatches = visit acc' lb_low lb_high left in
        visit lmatches rb_low rb_high right in
    loop [] tree

  let is_empty = function
    | Empty -> true
    | Node _ -> false

  (* vantage point of the root node.
     @raise Not_found if the tree is empty. *)
  let root = function
    | Empty -> raise Not_found
    | Node { vp; _ } -> vp

  (* test if the tree invariant holds (expensive: every descendant's
     distance to each ancestor's vp is recomputed).
     If it doesn't, then we are in trouble... *)
  let rec check = function
    | Empty -> true
    | Node { vp; lb_low; lb_high; middle; rb_low; rb_high; left; right } ->
      let bounds_OK = (0.0 <= lb_low) &&
                      (lb_low <= lb_high) &&
                      ((lb_high < middle) || (0.0 = middle)) &&
                      (middle <= rb_low) &&
                      (rb_low <= rb_high) in
      (bounds_OK &&
       List.for_all (fun p -> P.dist vp p < middle) (to_list left) &&
       List.for_all (fun p -> P.dist vp p >= middle) (to_list right) &&
       check left && check right)

  (* no longer raised by [find]; kept so existing references still compile *)
  exception Found of P.t

  (* A point of [tree] at distance 0 from [query]; only the single
     root-to-leaf path such a point must lie on is followed.
     @raise Not_found if there is none. *)
  let find query tree =
    let rec loop = function
      | Empty -> raise Not_found
      | Node { vp; middle; left; right; _ } ->
        let d = P.dist vp query in
        if d = 0.0 then vp
        else if d < middle then loop left
        else loop right in
    loop tree

  (* true iff [find query tree] succeeds *)
  let mem query tree =
    match find query tree with
    | _ -> true
    | exception Not_found -> false
end