kb84tkhrのブログ

何を書こうか考え中です あ、あと組織とは関係ないってやつです 個人的なやつ

DSL_2_C: Range Search (kD Tree) (続き6)

前回のコードは単純にC++Pythonに書き換えただけで
ちょっと何やってるのか飲み込めてないところがあったので、
コードを読み直しつつグローバル変数をなくしたり再帰を整理したり
少しは速くなるはず

from sys import stdin

class Node():
    def __init__(self):
        self.location = None
        self.l = self.r = None

class Point():
    def __init__(self, id=None, x=None, y=None):
        self.id = id
        self.x = x
        self.y = y

    def __lt__(self, other):
        return self.id < other.id

def read_points():
    n = int(stdin.readline())
    P = []
    for i in range(n):
        x, y = [int(x) for x in stdin.readline().split()]
        P.append(Point(i, x, y))
    return P

def make_kd_tree(P):

    def rec(l, r, depth):
        if l >= r:
            return None

        mid = (l + r) // 2

        if depth % 2 == 0:
            P[l:r] = sorted(P[l:r], key=lambda p: p.x)
        else:
            P[l:r] = sorted(P[l:r], key=lambda p: p.y)

        node = Node()
        node.location = mid
        node.l = rec(l, mid, depth + 1)
        node.r = rec(mid + 1, r, depth + 1)

        return node

    root = rec(0, len(P), 0)
    return root

def find(P, root, sx, tx, sy, ty):
    ans = []

    def rec(node, depth):
        nonlocal P, sx, tx, sy, ty, ans

        x = P[node.location].x
        y = P[node.location].y

        if sx <= x <= tx and sy <= y <= ty:
            ans.append(P[node.location])

        if depth % 2 == 0:
            if node.l and sx <= x:
                rec(node.l, depth + 1)
            if node.r and x <= tx:
                rec(node.r, depth + 1)
        else:
            if node.l and sy <= y:
                rec(node.l, depth + 1)
            if node.r and y <= ty:
                rec(node.r, depth + 1)

    rec(root, 0)
    return ans

def process_queries(P, root):
    q = int(stdin.readline())
    for i in range(q):
        sx, tx, sy, ty = [int(x) for x in stdin.readline().split()]
        ans = find(P, root, sx, tx, sy, ty)
        [print(a.id) for a in sorted(ans)]
        print()

def main():
    P = read_points()
    root = make_kd_tree(P)
    process_queries(P, root)

if __name__ == '__main__':
    main()

といってもオーダが下がるわけでもないわけで
Case #12が1.57s (AC)
Case #13が7.15s (TLE)

そんなもんかな
実は、Tをごっそり配列で取ってインデックスでアクセスするのをやめて、
ひとつずつ割り当てて直接アクセスするようにしたところで少し遅くなってる
でもこっちが自然だと思うので自分的にはこれができあがり

他のひとの解答を見てみた
通ってるコードはどうやらみんな違うアルゴリズム
xでソートして二分探索で範囲を絞り、その中でyでソートしてまた二分探索、
っていう感じ
√n個ごとのグループに分けてるところがミソなのかな
ちゃんと分かるほどは読んでない

ところでテキストのNode構造体には、いかにも親ノードを覚えておきそうな
フィールドがあるんだけれども使われてないどころか値を入れられもしてない
とにかく木といえば親を覚えておくもの、って感じでフィールドは作って
しまうのだろうか