kb84tkhrのブログ

何を書こうか考え中です あ、あと組織とは関係ないってやつです 個人的なやつ

DSL_2_C: Range Search (kD Tree) (続き2)

とりあえずやろうと思っていることをなんとかかんとか表現してみた
速くなっているのかはわからない

class Area:
    """Axis-aligned rectangle used as a node of the area tree.

    ``sx``/``sy`` are the lower bounds and ``tx``/``ty`` the upper bounds of
    the rectangle.  A fresh instance spans the whole ``MIN``..``MAX`` range,
    holds no point indices and has no children.
    """

    # Sentinel coordinate limits used for an unbounded area.
    MAX = 10 ** 9
    MIN = -MAX

    def __init__(self):
        # Lower-left corner.
        self.sx = Area.MIN
        self.sy = Area.MIN
        # Upper-right corner.
        self.tx = Area.MAX
        self.ty = Area.MAX
        # Indices (into the point list) of the points inside this area.
        self.indices = None
        # Children: "left/bottom" half and "right/top" half.
        self.lb_area = None
        self.rt_area = None

def read_points():
    """Read the point list from standard input.

    The first line holds the number of points; each following line holds one
    point as two whitespace-separated integers ``x y``.

    Returns:
        A list of ``(x, y)`` tuples in input order.
    """
    count = int(input())
    coords = []
    while len(coords) < count:
        # Unpacking into exactly two names rejects malformed lines.
        x, y = map(int, input().split())
        coords.append((x, y))
    return coords

# dir = 0 -> 左右に分割
# dir = 1 -> 上下に分割

def find_mid(points, indices, dir):
    """Return the ``dir`` coordinate of the point at the middle of ``indices``.

    Args:
        points: Sequence of ``(x, y)`` points.
        indices: Non-empty sequence of indices into ``points``.
        dir: 0 to read the x coordinate, 1 to read the y coordinate.
    """
    middle_position = len(indices) // 2
    middle_point = points[indices[middle_position]]
    return middle_point[dir]

def partition(points, mid, indices, dir):
    """Split ``indices`` by comparing each point's ``dir`` coordinate to ``mid``.

    Args:
        points: Sequence of ``(x, y)`` points.
        mid: Threshold coordinate value.
        indices: Iterable of indices into ``points``.
        dir: 0 to compare x coordinates, 1 to compare y coordinates.

    Returns:
        A pair ``(below, rest)``: indices whose coordinate is strictly less
        than ``mid``, and all remaining indices, each in input order.
    """
    below, rest = [], []
    for idx in indices:
        bucket = below if points[idx][dir] < mid else rest
        bucket.append(idx)
    return below, rest

def divide_area(points, area, dir):
    """Recursively split ``area`` into two child areas, kd-tree style.

    The indices held by ``area`` are sorted in place by the ``dir`` coordinate
    and cut at the median *position*.  The first half becomes ``area.lb_area``
    (upper bound on that axis set to the median value) and the second half
    becomes ``area.rt_area`` (lower bound set to the median value).  Children
    are split on the other axis.

    Cutting by position rather than by value guarantees both halves are
    non-empty and strictly smaller than the parent, so recursion terminates
    even when several points share the same coordinates.  Points equal to the
    median may land in either half; both halves treat the median as an
    inclusive bound, so every point still lies inside its area.

    Args:
        points: Sequence of ``(x, y)`` points.
        area: Node to split.  Must expose ``sx``, ``sy``, ``tx``, ``ty``,
            ``indices`` (a list), ``lb_area`` and ``rt_area``.
        dir: 0 to split left/right (on x), 1 to split bottom/top (on y).

    Returns:
        None.  ``area`` is modified in place; an area holding at most one
        index is left without children.
    """
    # Imported here because no module-level ``import copy`` is visible.
    import copy

    indices = area.indices
    if len(indices) <= 1:
        return

    indices.sort(key=lambda i: points[i][dir])
    half = len(indices) // 2
    mid = points[indices[half]][dir]

    # Children are shallow copies taken while ``area`` still has no children
    # of its own, so they start with ``lb_area``/``rt_area`` unset.
    lb_area = copy.copy(area)
    if dir == 0:
        lb_area.tx = mid
    else:
        lb_area.ty = mid
    lb_area.indices = indices[:half]
    divide_area(points, lb_area, 1 - dir)

    rt_area = copy.copy(area)
    if dir == 0:
        rt_area.sx = mid
    else:
        rt_area.sy = mid
    rt_area.indices = indices[half:]
    divide_area(points, rt_area, 1 - dir)

    area.lb_area = lb_area
    area.rt_area = rt_area

def make_area_tree(points):
    """Build the area tree covering every point and return its root.

    The root is a default ``Area`` holding the indices ``0 .. len(points)-1``;
    it is then subdivided starting with a left/right (x axis) split.
    """
    root = Area()
    root.indices = [*range(len(points))]
    divide_area(points, root, 0)
    return root

def find_points(sx, tx, sy, ty, points, tree):
    """Return the sorted indices of points strictly inside a query rectangle.

    All comparisons are strict, so a point lying exactly on the query boundary
    is not reported.

    Args:
        sx, tx: Lower and upper x bounds of the query (exclusive).
        sy, ty: Lower and upper y bounds of the query (exclusive).
        points: Sequence of ``(x, y)`` points.
        tree: Root node exposing ``sx``, ``sy``, ``tx``, ``ty``, ``indices``,
            ``lb_area`` and ``rt_area``.

    Returns:
        A list of matching point indices in ascending order.
    """
    hits = []
    pending = [tree]

    while pending:
        node = pending.pop()

        # Node rectangle and query rectangle do not overlap: nothing here.
        if node.sx > tx or sx > node.tx or node.sy > ty or sy > node.ty:
            continue

        # Node rectangle lies strictly inside the query: take all its points.
        if node.sx > sx and tx > node.tx and node.sy > sy and ty > node.ty:
            hits.extend(node.indices)
            continue

        left, right = node.lb_area, node.rt_area

        if left is None and right is None:
            # Leaf: test its first stored point directly against the query.
            if node.indices:
                idx = node.indices[0]
                point = points[idx]
                if sx < point[0] < tx and sy < point[1] < ty:
                    hits.append(idx)
            continue

        # Visit order is irrelevant because the result is sorted below.
        if right is not None:
            pending.append(right)
        if left is not None:
            pending.append(left)

    hits.sort()
    return hits

def process_queries(points, tree):
    """Read range queries from standard input and print the matching indices.

    The first line holds the number of queries; each following line holds
    four integers ``sx tx sy ty``.  The query rectangle is widened by 0.5 on
    every side before searching, so that the strict comparisons made by
    ``find_points`` include points on the integer boundary.

    For each query the matching indices are printed one per line, followed by
    a blank line (a query with no matches prints just the blank line).
    """
    query_count = int(input())
    for _ in range(query_count):
        sx, tx, sy, ty = map(int, input().split())
        hits = find_points(sx - 0.5, tx + 0.5, sy - 0.5, ty + 0.5, points, tree)
        # With no hits this prints an empty line, which doubles as the
        # separator; otherwise the separator is printed explicitly below.
        print("\n".join(map(str, hits)))
        if hits:
            print()

def main():
    """Read the points, build the area tree, then answer the queries."""
    pts = read_points()
    process_queries(pts, make_area_tree(pts))

# Run the solver only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

Case #13でTLEだからテストふたつ分速くなった!
それよりちゃんと答えが出ていることが驚き
(改行の数とか出力がソートされてなかったとかは直したけど)