DSL_2_C: Range Search (kD Tree) (続き2)
とりあえずやろうと思っていることをなんとかかんとか表現してみた
速くなっているのかはわからない
class Area:
    """One node of the 2-d tree: an axis-aligned box plus the points in it.

    The box is [sx, tx] on x and [sy, ty] on y.  A fresh node spans the
    default box [MIN, MAX] on both axes, has no point indices yet, and has
    no children.
    """

    # Default coordinate bounds of a freshly created node.
    MAX = 1000000000
    MIN = -MAX

    def __init__(self):
        low, high = Area.MIN, Area.MAX
        self.sx, self.sy = low, low
        self.tx, self.ty = high, high
        # List of point indices owned by this node (filled in by the caller).
        self.indices = None
        # Children: lb = left/bottom half, rt = right/top half.
        self.lb_area, self.rt_area = None, None
def read_points():
    """Read a point count, then that many "x y" integer lines, from stdin.

    Returns a list of (x, y) int tuples in input order.  A line that does
    not contain exactly two integers raises ValueError.
    """
    def parse_line():
        px, py = map(int, input().split())
        return (px, py)

    return [parse_line() for _ in range(int(input()))]
# dir = 0 -> split left/right (compare x coordinates)
# dir = 1 -> split bottom/top (compare y coordinates)
def find_mid(points, indices, dir):
    """Return the `dir` coordinate of the point at the middle of `indices`.

    When `indices` is sorted along `dir`, this is the median coordinate; for
    an even length it is the upper of the two middle entries.  An empty
    `indices` raises IndexError.
    """
    centre = indices[len(indices) // 2]
    return points[centre][dir]
def partition(points, mid, indices, dir):
    """Split `indices` around `mid` along axis `dir`, keeping input order.

    Returns (lb_indices, rt_indices): indices whose point coordinate is
    strictly less than `mid` go to the first list, all others to the second.
    """
    below = [i for i in indices if points[i][dir] < mid]
    rest = [i for i in indices if not points[i][dir] < mid]
    return below, rest
def divide_area(points, area, dir):
    """Split `area` in place into a 2-d tree over its point indices.

    Each node's `indices` list is sorted along the current axis and cut at
    its middle position.  The first half goes to `lb_area` (left/bottom) and
    the rest to `rt_area` (right/top).  The axis alternates at every level.
    The middle point's coordinate `mid` becomes the shared boundary of the
    two children: it is lb's upper bound and rt's lower bound.  Both children
    keep closed bounds, so a point lying exactly on `mid` is still inside
    whichever child holds it.  Nodes with at most one index stay leaves, so
    every leaf ends up holding exactly one point (or none for an empty root).

    Args:
        points: sequence of (x, y) pairs addressed by the node indices.
        area: root node to split.  It needs sx/sy/tx/ty bounds and an
            `indices` list; that list is sorted in place.  Children are
            shallow copies of their parent with narrowed bounds.
        dir: axis of the first split -- 0 splits left/right on x,
            1 splits bottom/top on y.
    """
    import copy  # local import: the file has no top-level import block

    # Explicit stack instead of recursion: same tree, no per-call overhead.
    pending = [(area, dir)]
    while pending:
        node, axis = pending.pop()
        indices = node.indices
        if len(indices) <= 1:
            continue
        indices.sort(key=lambda i: points[i][axis])
        half = len(indices) // 2
        mid = points[indices[half]][axis]
        # Cut by position rather than by comparing against `mid`.  With
        # duplicate coordinates a value-based cut can put every point on the
        # same side.  If that happens on both axes, splitting never
        # terminates.  Slicing a list of length >= 2 at len // 2 always
        # yields two non-empty halves.  Because the list is sorted, the
        # first half is <= mid and the second half is >= mid.
        lb = copy.copy(node)
        rt = copy.copy(node)
        lb.indices = indices[:half]
        rt.indices = indices[half:]
        lb.lb_area = lb.rt_area = None
        rt.lb_area = rt.rt_area = None
        if axis == 0:
            lb.tx = rt.sx = mid
        else:
            lb.ty = rt.sy = mid
        node.lb_area, node.rt_area = lb, rt
        pending.append((lb, 1 - axis))
        pending.append((rt, 1 - axis))
def make_area_tree(points):
    """Build and return the root Area of a 2-d tree over all of `points`.

    The root keeps Area's default MIN/MAX box and starts out owning every
    point index.  divide_area then splits it, beginning with a left/right
    (x-axis) cut.
    """
    root = Area()
    root.indices = [*range(len(points))]
    divide_area(points, root, 0)
    return root
def find_points(sx, tx, sy, ty, points, tree):
    """Return the sorted indices of points strictly inside the query box.

    A point p is reported when sx < p[0] < tx and sy < p[1] < ty (open
    bounds).  Callers wanting inclusive integer bounds can widen them by a
    half unit.

    Traversal rules:
    - A node whose closed box [sx, tx] x [sy, ty] misses the query is pruned.
    - A node whose box lies strictly inside the query contributes all of its
      `indices` without descending.
    - A leaf (no children) has each of its points tested individually.

    Args:
        sx, tx, sy, ty: query bounds on x and y.
        points: sequence of (x, y) pairs addressed by node indices.
        tree: root node with sx/sy/tx/ty, indices, lb_area and rt_area.
    """
    found = []
    pending = [tree]
    while pending:
        area = pending.pop()
        if tx < area.sx or area.tx < sx or ty < area.sy or area.ty < sy:
            continue
        if sx < area.sx and area.tx < tx and sy < area.sy and area.ty < ty:
            found.extend(area.indices)
            continue
        lb, rt = area.lb_area, area.rt_area
        if lb is None and rt is None:
            # Check every point in the leaf, not just the first one, so a
            # leaf holding several points is never silently truncated.
            for i in area.indices or ():
                p = points[i]
                if sx < p[0] < tx and sy < p[1] < ty:
                    found.append(i)
            continue
        # Push rt first so lb is visited first, matching the old recursion
        # order (the result is sorted below anyway).
        if rt is not None:
            pending.append(rt)
        if lb is not None:
            pending.append(lb)
    found.sort()
    return found
def process_queries(points, tree):
    """Answer the range queries read from stdin, one block of output each.

    Reads a query count, then that many "sx tx sy ty" integer lines.  For
    each query, prints the matching point indices one per line in ascending
    order, followed by a blank line.  A query with no matches prints only
    the blank line.
    """
    for _ in range(int(input())):
        sx, tx, sy, ty = map(int, input().split())
        # Widen the integer bounds by half a unit so that find_points'
        # strict comparisons still accept points lying on the boundary.
        hits = find_points(sx - .5, tx + .5, sy - .5, ty + .5, points, tree)
        # Every index carries its own newline; print's trailing newline
        # then supplies the blank line that terminates the query's output.
        print("".join(str(i) + "\n" for i in hits))
def main():
    """Read the points, build the 2-d tree once, then answer every query."""
    pts = read_points()
    process_queries(pts, make_area_tree(pts))
# Run the solver only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Case #13でTLEになったので、前回よりテスト2つ分速くなった!
それよりちゃんと答えが出ていることが驚き
(改行の数とか出力がソートされてなかったとかは直したけど)