kb84tkhrのブログ

何を書こうか考え中です あ、あと組織とは関係ないってやつです 個人的なやつ

ALDS1_9_C: Priority Queue (続き3)

他の人のソースでも見てみるか

見てみたらみんなheapq使ってました
自分で二分ヒープ書いて通すのは無理なのかな

heapqを使ってるコードでも
stdin.readline()やstdin.readlines()使ってるコードは3秒くらいなのに対し
input()を使ってるコードは8秒くらいで
やっぱりstdinの方が速い模様

さてheapqはあとで使ってみるとしてこっちのコードをどうするか
優先度付きキューとしてのインタフェースもはっきりしたので
クラスにしてみました
ついでにunittest使ってみたりファイルを分けたり

お作法がまだよくわかりません
pep8のチェックは通してますけど
そういうのはこういう本よりも実際のプロジェクトのソース見たほうが
わかりそうかな

ALDS1_9_C.py

#! /usr/bin/env python3
# -*- coding: utf-8 -*-

from sys import stdin
from priority_queue import NoMoreDataError
from priority_queue import PriorityQueue
from priority_queue import swap


def main():
    Q = PriorityQueue()
    lines = stdin.readlines()

    for line in lines:
        cmd = line.split()
        if cmd[0] == "insert":
            Q.insert(int(cmd[1]))
        elif cmd[0] == "extract":
            print(Q.extract_max())
        else:
            break


if __name__ == "__main__":
    main()

priority_queue.py

#! /usr/bin/env python3
# -*- coding: utf-8 -*-


def swap(A, i, j):
    tmp = A[i]
    A[i] = A[j]
    A[j] = tmp


class NoMoreDataError(Exception):
    pass


class PriorityQueue():
    def __init__(self, elems=[]):
        self.A = [None] + elems
        self._build_max_heap()

    def __repr__(self):
        return str(self.A[1:])

    def is_empty(self):
        return len(self.A) == 1

    def insert(self, k):
        self.A.append(k)

        i = self._last()
        if i == 1:
            return

        while i > 1:
            p = self._parent(i)
            if self.A[p] >= self.A[i]:
                break
            swap(self.A, p, i)
            i = p

    def extract_max(self):
        if self._last() == 0:
            raise NoMoreDataError
        if self._last() == 1:
            return self.A.pop()
        else:
            k = self.A[1]
            self.A[1] = self.A.pop()
            self._max_heapify(1)
            return k

    def to_array(self):
        R = []
        while not self.is_empty():
            R.append(self.extract_max())
        return R

    def _parent(self, k):
        return k // 2

    def _left(self, k):
        return 2 * k

    def _right(self, k):
        return 2 * k + 1

    def _last(self):
        return len(self.A) - 1

    def _max_heapify(self, i):
        l = self._left(i)
        r = self._right(i)
        if l <= self._last() and \
           self.A[l] > self.A[i]:
            largest = l
        else:
            largest = i
        if r <= self._last() and self.A[r] > self.A[largest]:
            largest = r

        if largest != i:
            swap(self.A, i, largest)
            self._max_heapify(largest)

    def _build_max_heap(self):
        for i in reversed(range(1, self._last() // 2 + 1)):
            self._max_heapify(i)

priority_queue_test.py

#! /usr/bin/env python3
# -*- coding: utf-8 -*-

import unittest

from random import randrange
from priority_queue import NoMoreDataError
from priority_queue import PriorityQueue
from priority_queue import swap


class TestPriorityQueue(unittest.TestCase):

    def test_sequence1(self):
        Q = PriorityQueue()
        with self.assertRaises(NoMoreDataError):
            Q.extract_max()

    def test_sequence2(self):
        Q = PriorityQueue()
        Q.insert(1)
        Q.insert(2)
        Q.insert(3)
        self.assertEqual(Q.extract_max(), 3)
        self.assertEqual(Q.extract_max(), 2)
        self.assertEqual(Q.extract_max(), 1)
        with self.assertRaises(NoMoreDataError):
            Q.extract_max()

    def test_sequence3(self):
        Q = PriorityQueue([1])
        self.assertEqual(Q.extract_max(), 1)
        with self.assertRaises(NoMoreDataError):
            Q.extract_max()

    def test_sequence4(self):
        Q = PriorityQueue([1, 2, 3])
        self.assertEqual(Q.extract_max(), 3)
        self.assertEqual(Q.extract_max(), 2)
        self.assertEqual(Q.extract_max(), 1)
        with self.assertRaises(NoMoreDataError):
            Q.extract_max()

    def _random_array(self, max, n):
        A = [randrange(max) for _ in range(n)]
        for i in range(n):
            swap(A, i, randrange(n))
        return A

    def _test_build_max_heap(self, A):
        Q = PriorityQueue(A)
        R = Q.to_array()
        self.assertEqual(list(reversed(sorted(A))), R)

    def _test_insert(self, A):
        Q = PriorityQueue()
        for n in A:
            Q.insert(n)
        R = Q.to_array()
        self.assertEqual(list(reversed(sorted(A))), R)

    def test_random_build_max_heap(self):
        for i in range(1000):
            self._test_build_max_heap(self._random_array(15, 10))

    def test_random_insert(self):
        for i in range(1000):
            self._test_insert(self._random_array(15, 10))


if __name__ == "__main__":
    unittest.main()

知ってる範囲でそれっぽく書いてみたつもりですが
まだどう書いていいのやらよくわかりません

このクラスは果たしてPriorityQueueという名前でいいのか
それともBinaryHeapとするべきか
一応、中身が二分ヒープであることは隠蔽して、大きいものから
取り出すというインタフェースだけにしているので
PriorityQueueという名前にしたけれどもまだ迷ってる

parentやleftやrightあたりほんとはスタティックメソッドか
クラスメソッドでいいと思うんだけど
self.って書くのも面倒なのにPriorityQueue.って書くのはつらい
PriorityQueueのためのメソッドだからクラスの外に出すのも
ちょっと気が引ける
短く書ける書き方はないかな

swapはどこに書けばいいんだ
NoMoreDataErrorもちょっと微妙
スタックとかキューとか作ったらそっちでも使いそうだし
それは同じモジュールに入れちゃえばいいのか

とかまだまだいろいろ