사용자 도구

사이트 도구


kb:intervaltree

Interval Tree

from http://discuss.joelonsoftware.com/default.asp?joel.3.10162.13

# Hypothetical RangedMap API from the linked thread. As the expected outputs
# below suggest: looking up a key inside a registered range returns every
# entry stored in that range, while a key outside all ranges behaves like a
# plain mapping entry.
rmap = RangedMap()
rmap.addRange(0, 10)
rmap.addRange(11, 35)

rmap[3] = 'some value'
rmap[6] = 'another value'
rmap[11] = 'foo'
rmap[50] = 'blah'

print(rmap[8])   # prints {3: 'some value', 6: 'another value'}
print(rmap[33])  # prints {11: 'foo'}
print(rmap[50])  # prints blah

바로 이런 것과 비슷한 일을 하기 위한 트리다. 주어진 점이나 구간과 겹치는 구간들을 찾아 준다. KdTree와 비슷해 보이지만 미묘하게 다른데, KdTree는 점(point)을 저장하는 반면 Interval Tree는 구간(interval)을 저장하고 구간끼리의 겹침을 질의한다.

C++ Implementation

Emin Martinian 씨의 홈페이지에서 가져온 소스를 템플릿 버전으로 변경하고, 구조를 좀 변경했다.

IntervalTree.h
////////////////////////////////////////////////////////////////////////////////
/// \file IntervalTree.h
/// \author excel96
/// \date 2006.5.2
///
/// http://www.csua.berkeley.edu/~emin/source_code/cpp_trees/index.html
/// 위의 주소에서 가져온 소스를 템플릿 버전으로 변경하고, 구조를 좀 변경했다.
////////////////////////////////////////////////////////////////////////////////
 
#ifndef __INTERVALTREE_H__
#define __INTERVALTREE_H__
 
#include <algorithm>
#include <cstddef>
#include <limits>
#include <stack>
#include <vector>
 
/// Deletion policy for a payload that is a single object allocated with
/// `new`: invoking the functor releases it with `delete`.
/// cIntervalTree applies the policy to an interval's Item when that
/// interval is destroyed.
template <typename T>
struct ptr_delete_policy
{
    void operator()(T target)
    {
        delete target;
    }
};
 
/// Deletion policy for a payload that is an array allocated with `new[]`:
/// invoking the functor releases it with `delete[]`.
/// cIntervalTree applies the policy to an interval's Item when that
/// interval is destroyed.
template <typename T>
struct array_delete_policy
{
    void operator()(T block)
    {
        delete [] block;
    }
};
 
/// Deletion policy that leaves the payload untouched. It is the default for
/// cIntervalTree: use it when items are plain values, or pointers whose
/// lifetime is managed elsewhere.
template <typename T>
struct no_delete_policy { void operator()(T /*item*/) {} };
 
////////////////////////////////////////////////////////////////////////////////
/// \class cIntervalTree
/// \brief "Introduction To Algorithms" by Cormen, Leisserson, and Rivest 
/// 이라는 책에 나오는 Interval Tree의 템플릿 구현이다.
///
/// 기본적인 사용법은 다음과 같다.
/// <pre>
/// typedef cIntervalTree<float, std::string> MyTree;
/// 
/// MyTree tree(-FLT_MAX, FLT_MAX);
/// tree.Insert(-5, -1, std::string("111"));
/// tree.Insert(11, 20, std::string("222"));
/// tree.Insert(21, 30, std::string("333"));
/// 
/// MyTree::RESULTS s;
/// tree.Enumerate(-2, -2, s);
/// for (size_t i=0; i<s.size(); ++i)
/// {
///     MyTree::cInterval* interval = s[i];
///     interval = interval;
/// }
/// </pre>
///
/// 구역이 서로 겹치지 않는 경우, Enumerate 함수의 변종을 사용하면 된다. 
/// <pre>
/// MyTree::cInterval* interval = Enumerate(-2, -2);
/// </pre>
////////////////////////////////////////////////////////////////////////////////
 
template <typename T, typename I, template <typename> class DeletionPolicy=no_delete_policy>
class cIntervalTree 
{
public:
    class cInterval
    {
    public:
        T LowPoint;
        T HighPoint;
        I Item;
 
        cInterval(T l, T h, I i) : LowPoint(l), HighPoint(h), Item(i) {}
        virtual ~cInterval() { DeletionPolicy<I> doDelete; doDelete(Item); }
    };
 
    typedef std::vector<cInterval*> RESULTS;
 
    struct cNode 
    {
    public:
        cInterval* Stored;
        T Key;
        T High;
        T MaxHigh;
        bool Red; 
        cNode* Left;
        cNode* Right;
        cNode* Parent;
 
        cNode() 
            : Stored(NULL), Red(false), Left(NULL), Right(NULL), Parent(NULL) 
        {
        }
 
        cNode(cInterval* interval) 
            : Stored(interval), Key(interval->LowPoint),  
            High(interval->HighPoint), MaxHigh(interval->HighPoint) 
        {
        }
    };
 
 
private:
    class cRecursionNode 
    {
    public:
        cNode* StartNode;
        size_t ParentIndex;
        bool   TryRightBranch;
 
        cRecursionNode(cNode* n=NULL, size_t p=0, bool right=false)
            : StartNode(n), ParentIndex(p), TryRightBranch(right)
        {
        }
    };
 
    typedef std::vector<cRecursionNode> RECURSION_STACK;
 
    const T MIN_VALUE;
    const T MAX_VALUE;
    cNode* m_Root;
    cNode* m_Nil;
 
    mutable RECURSION_STACK m_RecursionStack;
 
 
public:
    cIntervalTree(T minValue=std::numeric_limits<T>::min(), T maxValue=std::numeric_limits<T>::max())
    : m_Nil(new cNode), m_Root(new cNode), MIN_VALUE(minValue), MAX_VALUE(maxValue)
    {
        Assert(minValue < maxValue);
 
        m_Nil->Left = m_Nil->Right = m_Nil->Parent = m_Nil;
        m_Nil->Red = 0;
        m_Nil->Key = m_Nil->High = m_Nil->MaxHigh = MIN_VALUE;
        m_Nil->Stored = NULL;
 
        m_Root->Parent = m_Root->Left = m_Root->Right = m_Nil;
        m_Root->Key = m_Root->High = m_Root->MaxHigh = MAX_VALUE;
        m_Root->Red = 0;
        m_Root->Stored = NULL;
 
        m_RecursionStack.push_back(cRecursionNode(NULL, 0, false));
    }
 
    ~cIntervalTree()
    {
        cNode* x = m_Root->Left;
        std::stack<cNode*> stuffToFree;
 
        if (x != m_Nil)
        {
            if (x->Left != m_Nil) stuffToFree.push(x->Left);
            if (x->Right != m_Nil) stuffToFree.push(x->Right);
            delete x;
            while (!stuffToFree.empty())
            {
                x = stuffToFree.top();
                stuffToFree.pop();
                if (x->Left != m_Nil) stuffToFree.push(x->Left);
                if (x->Right != m_Nil) stuffToFree.push(x->Right);
                delete x;
            }
        }
 
        delete m_Nil;
        delete m_Root;
 
        m_RecursionStack.clear();
    }
 
 
public:
    cNode* Insert(T low, T high, I item)
    {
        cNode* y = NULL;
        cNode* x = new cNode(new cInterval(low, high, item));
        cNode* newNode = NULL;
 
        TreeInsertHelp(x);
        FixUpMaxHigh(x->Parent);
        newNode = x;
        x->Red = 1;
        while(x->Parent->Red) 
        { 
            if (x->Parent == x->Parent->Parent->Left)
            {
                y = x->Parent->Parent->Right;
                if (y->Red)
                {
                    x->Parent->Red = 0;
                    y->Red = 0;
                    x->Parent->Parent->Red = 1;
                    x = x->Parent->Parent;
                }
                else
                {
                    if (x == x->Parent->Right)
                    {
                        x = x->Parent;
                        LeftRotate(x);
                    }
                    x->Parent->Red = 0;
                    x->Parent->Parent->Red = 1;
                    RightRotate(x->Parent->Parent);
                }
            }
            else 
            {
                y = x->Parent->Parent->Left;
                if (y->Red)
                {
                    x->Parent->Red = 0;
                    y->Red = 0;
                    x->Parent->Parent->Red = 1;
                    x = x->Parent->Parent;
                }
                else
                {
                    if (x == x->Parent->Left)
                    {
                        x = x->Parent;
                        RightRotate(x);
                    }
                    x->Parent->Red = 0;
                    x->Parent->Parent->Red = 1;
                    LeftRotate(x->Parent->Parent);
                }
            }
        }
 
        m_Root->Left->Red = 0;
        return newNode;
    }
 
    cInterval* Delete(cNode* z)
    {
        cNode* y = ((z->Left == m_Nil) || (z->Right == m_Nil)) ? z : GetSuccessorOf(z);
        cNode* x = (y->Left == m_Nil) ? y->Right : y->Left;
        cInterval* returnValue = z->Stored;
 
        if (m_Root == (x->Parent = y->Parent)) 
        {
            m_Root->Left = x;
        }
        else
        {
            if (y == y->Parent->Left)
                y->Parent->Left = x;
            else
                y->Parent->Right = x;
        }
 
        if (y != z) 
        {
            Assert(y != m_Nil && "y is nil");
 
            y->MaxHigh = MIN_VALUE;
            y->Left = z->Left;
            y->Right = z->Right;
            y->Parent = z->Parent;
            z->Left->Parent = z->Right->Parent = y;
 
            if (z == z->Parent->Left)
                z->Parent->Left = y;
            else
                z->Parent->Right = y;
 
            FixUpMaxHigh(x->Parent);
            if (!(y->Red))
            {
                y->Red = z->Red;
                DeleteFixUp(x);
            }
            else
                y->Red = z->Red;
 
            delete z;
 
#ifdef _DEBUG
            CheckAssumptions();
            Assert(!m_Nil->Red && "nil is not black");
            Assert(m_Nil->MaxHigh == MIN_VALUE && "nil->MaxHigh != MIN_VALUE");
#endif
        }
        else
        {
            FixUpMaxHigh(x->Parent);
            if (!(y->Red)) DeleteFixUp(x);
            delete y;
#ifdef _DEBUG
            CheckAssumptions();
            Assert(!m_Nil->Red && "nil is not black");
            Assert(m_Nil->MaxHigh == MIN_VALUE && "nil->MaxHigh != MIN_VALUE");
#endif
        }
 
        return returnValue;
    }
 
    cNode* GetPredecessorOf(cNode* x) const
    {
        cNode* y = NULL;
 
        if (m_Nil != (y = x->Left)) 
        {
            while (y->Right != m_Nil) 
            {
                y = y->Right;
            }
 
            return y;
        }
        else
        {
            y = x->Parent;
            while (x == y->Left)
            {
                if (y == m_Root) return m_Nil;
                x = y;
                y = y->Parent;
            }
 
            return y;
        }
    }
 
    cNode* GetSuccessorOf(cNode* x) const
    {
        cNode* y = NULL;
 
        if (m_Nil != (y = x->Right)) 
        {
            while(y->Left != m_Nil) 
            {
                y = y->Left;
            }
 
            return y;
        }
        else
        {
            y = x->Parent;
            while(x == y->Right) 
            {
                x = y;
                y = y->Parent;
            }
            if (y == m_Root) return m_Nil;
            return y;
        }
    }
 
    bool Enumerate(T low, T high, RESULTS& results) const
    {
        size_t currentParent = 0;
        size_t stackTop = 1;
 
        cNode* x = m_Root->Left;
        bool stuffToDo = (x != m_Nil);
        while (stuffToDo)
        {
            struct 
            {
                bool operator()(const T& a1, const T& a2, const T& b1, const T& b2)
                {
                    if (a1 <= b1) return b1 <= a2;
                    else          return a1 <= b2;
                }
            } OVERLAP;
 
            if (OVERLAP(low, high, x->Key, x->High))
            {
                results.push_back(x->Stored);
                m_RecursionStack[currentParent].TryRightBranch = true;
            }
 
            if (x->Left->MaxHigh >= low) 
            {
                if (stackTop == m_RecursionStack.size())
                {
                    m_RecursionStack.push_back(cRecursionNode(x, currentParent, false));
                }
                else
                {
                    m_RecursionStack[stackTop].StartNode = x;
                    m_RecursionStack[stackTop].ParentIndex = currentParent;
                    m_RecursionStack[stackTop].TryRightBranch = false;
                }
 
                currentParent = stackTop++;
                x = x->Left;
            }
            else
            {
                x = x->Right;
            }
 
            stuffToDo = (x != m_Nil);
            while (!stuffToDo && stackTop > 1)
            {
                if (m_RecursionStack[--stackTop].TryRightBranch)
                {
                    x = m_RecursionStack[stackTop].StartNode->Right;
                    currentParent = m_RecursionStack[stackTop].ParentIndex;
                    m_RecursionStack[currentParent].TryRightBranch = true;
                    stuffToDo = (x != m_Nil);
                }
            }
        }
 
        Assert(stackTop == 1 &&
            "recursion stack not empty when exiting Enumerate()");
 
        return !results.empty();
    }
 
    cInterval* Enumerate(T low, T high) const
    {
        size_t currentParent = 0;
        size_t stackTop = 1;
 
        cNode* x = m_Root->Left;
        bool stuffToDo = (x != m_Nil);
        while (stuffToDo)
        {
            struct 
            {
                bool operator()(const T& a1, const T& a2, const T& b1, const T& b2)
                {
                    if (a1 <= b1) return b1 <= a2;
                    else          return a1 <= b2;
                }
            } OVERLAP;
 
            if (OVERLAP(low, high, x->Key, x->High))
                return x->Stored;
 
            if (x->Left->MaxHigh >= low) 
            {
                if (stackTop == m_RecursionStack.size())
                {
                    m_RecursionStack.push_back(cRecursionNode(x, currentParent, false));
                }
                else
                {
                    m_RecursionStack[stackTop].StartNode = x;
                    m_RecursionStack[stackTop].ParentIndex = currentParent;
                    m_RecursionStack[stackTop].StartNode = false;
                }
 
                currentParent = stackTop++;
                x = x->Left;
            }
            else
            {
                x = x->Right;
            }
 
            stuffToDo = (x != m_Nil);
            while (!stuffToDo && stackTop > 1)
            {
                if (m_RecursionStack[--stackTop].TryRightBranch)
                {
                    x = m_RecursionStack[stackTop].StartNode->Right;
                    currentParent = m_RecursionStack[stackTop].ParentIndex;
                    m_RecursionStack[currentParent].TryRightBranch = true;
                    stuffToDo = (x != m_Nil);
                }
            }
        }
 
        Assert(stackTop == 1 &&
            "recursion stack not empty when exiting Enumerate()");
 
        return NULL;
    }
 
    void CheckAssumptions() const
    {
        Verify(m_Nil->Key == MIN_VALUE);
        Verify(m_Nil->High == MIN_VALUE);
        Verify(m_Nil->MaxHigh == MIN_VALUE);
        Verify(m_Root->Key == MAX_VALUE);
        Verify(m_Root->High == MAX_VALUE);
        Verify(m_Root->MaxHigh == MAX_VALUE);
        Verify(m_Nil->Stored == NULL);
        Verify(m_Root->Stored == NULL);
        Verify(m_Nil->Red == 0);
        Verify(m_Root->Red == 0);
        CheckMaxHighFields(m_Root->Left);
    }
 
 
private:
    void LeftRotate(cNode* x)
    {
        cNode* y = x->Right;
        x->Right = y->Left;
 
        if (y->Left != m_Nil) 
            y->Left->Parent = x; 
 
        y->Parent = x->Parent;  
 
        if (x == x->Parent->Left)
        {
            x->Parent->Left = y;
        }
        else
        {
            x->Parent->Right = y;
        }
        y->Left = x;
        x->Parent = y;
 
        x->MaxHigh = std::max(x->Left->MaxHigh, std::max(x->Right->MaxHigh, x->High));
        y->MaxHigh = std::max(x->MaxHigh, std::max(y->Right->MaxHigh, y->High));
 
#ifdef _DEBUG
        CheckAssumptions();
        Assert(!m_Nil->Red && "nil is not red");
        Assert(m_Nil->MaxHigh == MIN_VALUE && "nil->MaxHigh != MIN_VALUE");
#endif
    }
 
    void RightRotate(cNode* y)
    {
        cNode* x = y->Left;
        y->Left = x->Right;
 
        if (m_Nil != x->Right) x->Right->Parent = y; 
 
        x->Parent = y->Parent;
        if (y == y->Parent->Left)
        {
            y->Parent->Left = x;
        }
        else
        {
            y->Parent->Right = x;
        }
 
        x->Right = y;
        y->Parent = x;
 
        y->MaxHigh = std::max(y->Left->MaxHigh, std::max(y->Right->MaxHigh, y->High));
        x->MaxHigh = std::max(x->Left->MaxHigh, std::max(y->MaxHigh, x->High));
 
#ifdef _DEBUG
        CheckAssumptions();
        Assert(!m_Nil->Red && "nil is not red");
        Assert(m_Nil->MaxHigh == MIN_VALUE && "nil->MaxHigh != MIN_VALUE");
#endif
    }
 
    void TreeInsertHelp(cNode* z)
    {
        cNode* y = m_Root;
        cNode* x = m_Root->Left;
 
        z->Left = z->Right = m_Nil;
 
        while (x != m_Nil)
        {
            y = x;
            if ( x->Key > z->Key)
                x = x->Left;
            else
                x = x->Right;
        }
 
        z->Parent = y;
 
        if (y == m_Root || y->Key > z->Key)
            y->Left = z;
        else
            y->Right = z;
 
        Assert(!m_Nil->Red && "nil is not red");
        Assert(m_Nil->MaxHigh == MIN_VALUE && "nil->MaxHigh != MIN_VALUE");
    }
 
    void FixUpMaxHigh(cNode* x)
    {
        while(x != m_Root)
        {
            x->MaxHigh = std::max(x->High, std::max(x->Left->MaxHigh, x->Right->MaxHigh));
            x = x->Parent;
        }
 
#ifdef _DEBUG
        CheckAssumptions();
#endif
    }
 
    void DeleteFixUp(cNode* x)
    {
        cNode* w = NULL;
        cNode* rootLeft = m_Root->Left;
 
        while ((!x->Red) && (rootLeft != x))
        {
            if (x == x->Parent->Left)
            {
                w = x->Parent->Right;
                if (w->Red)
                {
                    w->Red = 0;
                    x->Parent->Red = 1;
                    LeftRotate(x->Parent);
                    w = x->Parent->Right;
                }
 
                if ((!w->Right->Red) && (!w->Left->Red))
                {
                    w->Red = 1;
                    x = x->Parent;
                }
                else
                {
                    if (!w->Right->Red)
                    {
                        w->Left->Red = 0;
                        w->Red = 1;
                        RightRotate(w);
                        w = x->Parent->Right;
                    }
                    w->Red = x->Parent->Red;
                    x->Parent->Red = 0;
                    w->Right->Red = 0;
                    LeftRotate(x->Parent);
                    x = rootLeft; 
                }
            }
            else 
            {
                w = x->Parent->Left;
                if (w->Red)
                {
                    w->Red = 0;
                    x->Parent->Red = 1;
                    RightRotate(x->Parent);
                    w = x->Parent->Left;
                }
                if ((!w->Right->Red) && (!w->Left->Red))
                {
                    w->Red = 1;
                    x = x->Parent;
                }
                else
                {
                    if (!w->Left->Red)
                    {
                        w->Right->Red = 0;
                        w->Red = 1;
                        LeftRotate(w);
                        w = x->Parent->Left;
                    }
                    w->Red = x->Parent->Red;
                    x->Parent->Red = 0;
                    w->Left->Red = 0;
                    RightRotate(x->Parent);
                    x = rootLeft; 
                }
            }
        }
        x->Red = 0;
 
#ifdef _DEBUG
        CheckAssumptions();
        Assert(!m_Nil->Red && "nil is not black");
        Assert(m_Nil->MaxHigh == MIN_VALUE && "nil->MaxHigh != MIN_VALUE");
#endif
    }
 
    void CheckMaxHighFields(cNode* x) const
    {
        if (x != m_Nil)
        {
            CheckMaxHighFields(x->Left);
            if (CheckMaxHighFieldsHelper(x,x->MaxHigh,0) <= 0)
                Assert(false && "error found in CheckMaxHighFields");
            CheckMaxHighFields(x->Right);
        }
    }
 
    int CheckMaxHighFieldsHelper(cNode* y, const T currentHigh, int match) const
    {
        if (y != m_Nil)
        {
            match = CheckMaxHighFieldsHelper(y->Left,currentHigh,match) ? 1 : match;
            Verify(y->High <= currentHigh);
            if (y->High == currentHigh) match = 1;
            match = CheckMaxHighFieldsHelper(y->Right,currentHigh,match) ? 1 : match;
        }
 
        return match;
    }
};
 
#endif
sample.cpp
typedef cIntervalTree<float, std::string> MyTree;
 
MyTree tree(-FLT_MAX, FLT_MAX);
tree.Insert(-5, -1, std::string("111"));
tree.Insert(11, 20, std::string("222"));
tree.Insert(21, 30, std::string("333"));
 
MyTree::RESULTS s;
tree.Enumerate(-2, -2, s);
for (size_t i=0; i<s.size(); ++i)
{
    MyTree::INTERVAL* interval = s[i];
    cout << interval->Item << std::endl;
}

구간들이 서로 겹치지 않는 경우에는, 처음 찾은 구간 하나만(없으면 NULL) 반환하는 Enumerate 함수의 변종을 사용하면 된다.

// Single-result variant: returns NULL when no stored interval overlaps.
MyTree::cInterval* interval = tree.Enumerate(-2, -2);

링크

kb/intervaltree.txt · 마지막으로 수정됨: 2014/11/10 19:52 (바깥 편집)