00001 #ifndef KDTREE_HEADER_
00002 #define KDTREE_HEADER_
00003
00004 #ifdef NDEBUG
00005 #include <cstring>
00006 #endif
00007
00014 namespace NOSPACE {
00015 using namespace std;
00016 }
00017
00019 namespace FieldMath {
00021 template<class I1,class I2,class Transf> inline
00022 Transf transform2( I1 i1, I1 iEnd1, I2 i2, Transf transformer ) {
00023 for (; i1!=iEnd1; ++i1,++i2)
00024 transformer(*i1,*i2);
00025 return transformer;
00026 }
00028 template<class I1,class I2,class I3,class Transf> inline
00029 Transf transform3( I1 i1, I1 iEnd1, I2 i2, I3 i3, Transf transformer ) {
00030 for (; i1!=iEnd1; ++i1,++i2,++i3)
00031 transformer(*i1,*i2,*i3);
00032 return transformer;
00033 }
00034
00036 template<class T> inline T* assign(const T *a,int length,T *b) {
00037 #ifndef NDEBUG
00038 copy( a, a+length, b );
00039 #else
00040 memcpy( b, a, length*sizeof(T) );
00041 #endif
00042 return b;
00043 }
00044
00045 namespace NOSPACE {
00047 template<class T,bool CheckNaNs> struct MoveToBounds {
00048 T sqrError;
00049
00051 MoveToBounds()
00052 : sqrError(0) {}
00053
00056 void operator()(const T &point,const T bounds[2],T &result) {
00057 if ( CheckNaNs && isNaN(point) )
00058 return;
00059 if ( point < bounds[0] ) {
00060 sqrError+= sqr(bounds[0]-point);
00061 result= bounds[0];
00062 } else
00063 if ( point > bounds[1] ) {
00064 sqrError+= sqr(point-bounds[1]);
00065 result= bounds[1];
00066 } else
00067 result= point;
00068 }
00069 };
00070 }
00073 template<class T,bool CheckNaNs> inline
00074 T moveToBounds_copy(const T *point,const T (*bounds)[2],int length,T *result) {
00075 return transform3( point, point+length, bounds, result
00076 , MoveToBounds<T,CheckNaNs>() ) .sqrError;
00077 }
00078 }
00079
00080 template<class T> class KDBuilder;
00083 template<class T> class KDTree {
00084 public:
00085 friend class KDBuilder<T>;
00086 typedef KDBuilder<T> Builder;
00087 protected:
00089 struct Node {
00090 int coord;
00091 T threshold;
00093 };
00094 typedef T (*Bounds)[2];
00095
00096 public:
00097 const int depth
00098 , length
00099 , count;
00100 protected:
00101 Node *nodes;
00102 int *dataIDs;
00103 Bounds bounds;
00104
00106 KDTree(int length_,int count_)
00107 : depth( log2ceil(count_) ), length(length_), count(count_)
00108 , nodes( new Node[count_] ), dataIDs( new int[count_] ), bounds( new T[length_][2] ) {
00109 }
00110
00112 KDTree(KDTree &other)
00113 : depth(other.depth), length(other.length), count(other.count)
00114 , nodes(other.nodes), dataIDs(other.dataIDs), bounds(other.bounds) {
00115 other.nodes= 0;
00116 other.dataIDs= 0;
00117 other.bounds= 0;
00118 }
00119
00122 int leafID2dataID(int leafID) const {
00123 ASSERT( count<=leafID && leafID<2*count );
00124 int index= leafID-powers[depth];
00125 if (index<0)
00126 index+= count;
00127 ASSERT( 0<=index && index<count );
00128 return dataIDs[index];
00129 }
00130
00131 public:
00133 ~KDTree() {
00134
00135 delete[] nodes;
00136 delete[] dataIDs;
00137 delete[] bounds;
00138 }
00139
00144 class PointHeap {
00146 struct HeapNode {
00147 int nodeIndex;
00148 T *data;
00151 HeapNode() {}
00153 HeapNode(int nodeIndex_,T *data_): nodeIndex(nodeIndex_), data(data_) {}
00154
00156 T& getSE() { return *data; }
00158 T getSE() const { return *data; }
00160 T* getNearest() { return data+1; }
00161 };
00163 struct HeapOrder {
00164 bool operator()(const HeapNode &a,const HeapNode &b)
00165 { return a.getSE() > b.getSE(); }
00166 };
00167
00168 const KDTree &kd;
00169 const T* const point;
00170 vector<HeapNode> heap;
00171 BulkAllocator<T> allocator;
00172 public:
00175 PointHeap(const KDTree &tree,const T *point_,bool checkNaNs)
00176 : kd(tree), point(point_) {
00177 ASSERT(point);
00178
00179 HeapNode rootNode( 1, allocator.makeField(kd.length+1) );
00180
00181 using namespace FieldMath;
00182 rootNode.getSE()= checkNaNs
00183 ? moveToBounds_copy<T,true> ( point, kd.bounds, kd.length, rootNode.getNearest() )
00184 : moveToBounds_copy<T,false>( point, kd.bounds, kd.length, rootNode.getNearest() );
00185
00186 heap.reserve(kd.depth*2);
00187 heap.push_back(rootNode);
00188 }
00189
00191 bool isEmpty()
00192 { return heap.empty(); }
00193
00195 T getTopSE() {
00196 ASSERT( !isEmpty() );
00197 return heap[0].getSE();
00198 }
00199
00202 template<bool CheckNaNs> int popLeaf(T maxSE) {
00203
00204 makeTopLeaf<CheckNaNs>(maxSE);
00205 int result= kd.leafID2dataID( heap.front().nodeIndex );
00206
00207 pop_heap( heap.begin(), heap.end(), HeapOrder() );
00208 heap.pop_back();
00209 return result;
00210 }
00211 protected:
00214 template<bool CheckNaNs> void makeTopLeaf(T maxSE);
00215
00216 };
00217 };
00218
00219
00220 template<class T> template<bool CheckNaNs>
00221 void KDTree<T>::PointHeap::makeTopLeaf(T maxSE) {
00222 ASSERT( !isEmpty() );
00223
00224 if ( heap[0].nodeIndex >= kd.count )
00225 return;
00226 PtrInt oldHeapSize= heap.size();
00227 HeapNode heapRoot= heap[0];
00228
00229 do {
00230 const Node &node= kd.nodes[heapRoot.nodeIndex];
00231
00232
00233
00234
00235 bool validCoord= !CheckNaNs || !isNaN(point[node.coord]);
00236
00237 T newSE;
00238 bool goRight;
00239 if (validCoord) {
00240 Real oldDiff= Real(point[node.coord]) - heapRoot.getNearest()[node.coord];
00241 Real newDiff= Real(point[node.coord]) - node.threshold;
00242 goRight= newDiff>0;
00243 newSE= heapRoot.getSE() - sqr(oldDiff) + sqr(newDiff);
00244 ASSERT( newSE >= heapRoot.getSE() );
00245 } else {
00246 newSE= heapRoot.getSE();
00247 goRight= false;
00248 }
00249
00250 heapRoot.nodeIndex= heapRoot.nodeIndex*2 + goRight;
00251
00252 if (newSE>maxSE)
00253 continue;
00254
00255 HeapNode newHNode;
00256 newHNode.data= allocator.makeField(kd.length+1);
00257 newHNode.getSE()= newSE;
00258 newHNode.nodeIndex= heapRoot.nodeIndex-goRight+!goRight;
00259
00260 FieldMath::assign( heapRoot.getNearest(), kd.length, newHNode.getNearest() );
00261 if (validCoord)
00262 newHNode.getNearest()[node.coord]= node.threshold;
00263
00264 heap.push_back(newHNode);
00265
00266 } while ( heapRoot.nodeIndex < kd.count );
00267
00268 heap[0]= heapRoot;
00269
00270 typename vector<HeapNode>::iterator it= heap.begin()+oldHeapSize;
00271 do
00272 push_heap( heap.begin(), it, HeapOrder() );
00273 while ( it++ != heap.end() );
00274 }
00275
00277 template<class T> class KDBuilder: public KDTree<T> {
00278 public:
00279 typedef KDTree<T> Tree;
00280 typedef T BoundsPair[2];
00281 typedef typename Tree::Bounds Bounds;
00283 typedef int (KDBuilder::*CoordChooser)
00284 (int nodeIndex,int *beginIDs,int *endIDs,int depthLeft) const;
00285 protected:
00286 using Tree::depth; using Tree::length; using Tree::count;
00287 using Tree::nodes; using Tree::dataIDs; using Tree::bounds;
00288
00289 const T *data;
00290 const CoordChooser chooser;
00291 mutable Bounds chooserTmp;
00292
00293 KDBuilder(const T *data_,int length,int count,CoordChooser chooser_)
00294 : Tree(length,count), data(data_), chooser(chooser_), chooserTmp(0) {
00295 ASSERT( length>0 && count>0 && chooser && data );
00296
00297 for (int i=0; i<count; ++i)
00298 dataIDs[i]= i;
00299 getBounds(bounds);
00300 if (count>1)
00301 buildNode(1,dataIDs,dataIDs+count,depth);
00302 delete[] chooserTmp;
00303 DEBUG_ONLY( chooserTmp= 0; data= 0; )
00304 }
00305
00307 struct NewBounds {
00308 void operator()(const T &val,BoundsPair &bounds) const
00309 { bounds[0]= bounds[1]= val; }
00310 };
00312 struct BoundsExpander {
00313 void operator()(const T &val,BoundsPair &bounds) const {
00314 if ( val < bounds[0] )
00315 bounds[0]= val; else
00316 if ( val > bounds[1] )
00317 bounds[1]= val;
00318 }
00319 };
00322 void getBounds(Bounds boundsRes) const {
00323 using namespace FieldMath;
00324 ASSERT(length>0);
00325
00326 transform2(data,data+length,boundsRes,NewBounds());
00327 int count= Tree::count;
00328 const T *nowData= data;
00329
00330 while (--count) {
00331 nowData+= length;
00332 transform2( nowData, nowData+length, boundsRes, BoundsExpander() );
00333 }
00334 }
00337 void getBounds(const int *beginIDs,const int *endIDs,Bounds boundsRes) const {
00338 using namespace FieldMath;
00339 ASSERT(endIDs>beginIDs);
00340
00341 const T *begin= data + *beginIDs*length;
00342 transform2(begin,begin+length,boundsRes,NewBounds());
00343
00344 while ( ++beginIDs != endIDs ) {
00345 begin= data + *beginIDs*length;
00346 transform2( begin, begin+length, boundsRes, BoundsExpander() );
00347 }
00348 }
00349
00352 void buildNode(int nodeIndex,int *beginIDs,int *endIDs,int depthLeft);
00353
00354 public:
00357 static Tree* makeTree(const T *data,int length,int count,CoordChooser chooser) {
00358 KDBuilder builder(data,length,count,chooser);
00359
00360 return new Tree(builder);
00361 }
00362
00364 int choosePrecise(int nodeIndex,int *beginIDs,int *endIDs,int ) const;
00366 int chooseFast(int ,int* ,int* ,int depthLeft) const
00367 { return depthLeft%length; }
00369 int chooseRand(int ,int* ,int* ,int ) const
00370 { return rand()%length; }
00373 int chooseApprox(int nodeIndex,int* ,int* ,int depthLeft) const;
00374 };
00375
00376
00377
00378 namespace NOSPACE {
00380 template<class T> struct MaxDiffCoord {
00381 typedef typename KDBuilder<T>::BoundsPair BoundsPair;
00382
00383 T maxDiff;
00384 int bestIndex
00385 , nextIndex;
00386
00388 MaxDiffCoord(const BoundsPair& bounds0)
00389 : maxDiff(bounds0[1]-bounds0[0]), bestIndex(0), nextIndex(1) {}
00390
00392 void operator()(const BoundsPair& bounds_i) {
00393 T diff= bounds_i[1]-bounds_i[0];
00394 if (diff>maxDiff) {
00395 bestIndex= nextIndex;
00396 maxDiff= diff;
00397 }
00398 ++nextIndex;
00399 }
00400 };
00401 }
00402 template<class T> int KDBuilder<T>
00403 ::choosePrecise(int nodeIndex,int *beginIDs,int *endIDs,int) const {
00404 ASSERT( nodeIndex>0 && beginIDs && endIDs && beginIDs<endIDs );
00405
00406 BoundsPair boundsStorage[length];
00407 const BoundsPair *localBounds;
00408
00409 if ( nodeIndex>1 ) {
00410 localBounds= boundsStorage;
00411 getBounds( beginIDs, endIDs, boundsStorage );
00412 } else
00413 localBounds= this->bounds;
00414
00415 MaxDiffCoord<T> mdc= for_each
00416 ( localBounds+1, localBounds+length, MaxDiffCoord<T>(localBounds[0]) );
00417 ASSERT( mdc.nextIndex == length );
00418 return mdc.bestIndex;
00419 }
00420
00421 template<class T> int KDBuilder<T>::chooseApprox(int nodeIndex,int*,int*,int) const {
00422 using namespace FieldMath;
00423 ASSERT(nodeIndex>0);
00424
00425 int myDepth= log2ceil(nodeIndex+1)-1;
00426 if (!myDepth) {
00427 ASSERT(nodeIndex==1);
00428 chooserTmp= new BoundsPair[length*(depth+1)];
00429 assign( bounds, length, chooserTmp );
00430 }
00431
00432 Bounds myBounds= chooserTmp+length*myDepth;
00433 if (myDepth) {
00434 const typename Tree::Node &parent= nodes[nodeIndex/2];
00435 Bounds parentBounds= myBounds-length;
00436 if (nodeIndex%2) {
00437 assign( parentBounds, length, myBounds );
00438 myBounds[parent.coord][0]= parent.threshold;
00439 } else
00440 if ( nodeIndex+1 < count ) {
00441 myBounds[parent.coord][0]= parentBounds[parent.coord][0];
00442 myBounds[parent.coord][1]= parent.threshold;
00443 } else {
00444 ASSERT( nodeIndex+1 == count );
00445 assign( parentBounds, length, myBounds );
00446 myBounds[parent.coord][1]= parent.threshold;
00447 }
00448 }
00449
00450 MaxDiffCoord<T> mdc= for_each( myBounds+1, myBounds+length, MaxDiffCoord<T>(myBounds[0]) );
00451 ASSERT( mdc.nextIndex == length );
00452 return mdc.bestIndex;
00453 }
00454
namespace NOSPACE {
    /// Compares point indices by one coordinate of row-major data whose rows
    /// hold `length` values, using T's operator<.
    template<class T> class IndexComparator {
        const T *data;  // coordinate index_ of row 0
        int length;     // row stride
    public:
        IndexComparator(const T *data_,int length_,int index_)
        : data(data_+index_), length(length_) {}

        /// True if row a's coordinate is less than row b's.
        bool operator()(int a,int b) const {
            const T &lhs= data[a*length];
            const T &rhs= data[b*length];
            return lhs < rhs;
        }
    };
}
00467 template<class T> void KDBuilder<T>
00468 ::buildNode(int nodeIndex,int *beginIDs,int *endIDs,int depthLeft) {
00469 int count= endIDs-beginIDs;
00470
00471 ASSERT( count>=2 && powers[depthLeft-1]<count && count<=powers[depthLeft] );
00472 --depthLeft;
00473
00474 bool shallowRight= ( count <= powers[depthLeft]+powers[depthLeft-1] );
00475 int *middle= shallowRight
00476 ? endIDs-powers[depthLeft-1]
00477 : beginIDs+powers[depthLeft];
00478
00479 int coord= (this->*chooser)(nodeIndex,beginIDs,endIDs,depthLeft);
00480 nth_element( beginIDs, middle , endIDs, IndexComparator<T>(data,length,coord) );
00481
00482 nodes[nodeIndex].coord= coord;
00483 nodes[nodeIndex].threshold= data[*middle*length+coord];
00484
00485 switch (count) {
00486 default:
00487
00488 buildNode( 2*nodeIndex+1, middle, endIDs, depthLeft-shallowRight );
00489 case 3:
00490
00491 buildNode( 2*nodeIndex, beginIDs, middle, depthLeft );
00492 case 2:
00493 ;
00494 }
00495 }
00496
00497 #endif // KDTREE_HEADER_