PythonでA*(A-Star)アルゴリズム

今回はA*アルゴリズムPythonでやってみます。
ゲームプログラマの間では、もはや常識となりつつある最短経路問題解決アルゴリズムです。


A*は、古典的手法であるダイクストラ法」を改良したものです。
スタート地点からノードnを通ってゴールに辿り付くとき、最短距離をf(n)とすると、


f(n) = g(n) + h(n)


とすることができます、g(n)は「スタートからノードnまでの最短距離」、h(n)は「ノードnからゴールまでの最短距離」です。
でも、最初から適切なg(n)とh(n)が判ってるなら苦労しませんよね。


だから、テキトーな予測値を使って、最短経路をある程度予測して効率的に経路探索をしてみようという事です。
テキトーな予測値を使った最短経路距離をf*(n)とすると


f*(n) = g*(n) + h*(n)


となります、f*(n)を求めるためにテキトーなg*(n)とh*(n)を求める訳です。
g*(n)は「スタートからノードnまでの最短距離予測」です、今まで辿ってきた経路を覚えていれば何とかなりそうです。
問題はh*(n)で「ノードnからゴールまでの最短距離予測」です。


A*が最短距離を保証するためにはh*(n)は、「0<= h*(n) <= h(n)」の範囲である必要があります。
つまり予測ではない本物のh(n)よりも、小さい値である必要があります。
ということですので、h*(n)には「直線距離(ユークリッド距離)」を使う場合が多いです。


余談ですがh*(n)が常に0の場合、「ダイクストラ法」と等しくなります、つまり、A*はh*(n)によってテキトーだがそこそこ近いヒント(これをヒューリスティック関数と呼ぶ)を与える事により、効率的に経路検索をしているわけです。


小難しい説明よりも、実際に動くものを見た方が早そうですね。
以下、PythonでA*を実装してみました。

#!/usr/local/bin/python
# -*- coding: utf-8 -*-
map_data = [
'OOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO',
'OS  O     O     O         O          O',
'O   O  O  O  O  O         O    OOOO GO',
'O      O     O  O   OOOO  O    O  OOOO',
'OO OOOOOOOOOOOOOOO  O     O    O     O',
'O                O  O     O          O',
'O        OOO     O  O     OOOOOOOOO  O',
'O  OO    O    OOOO  O     O      OO  O',
'O   O    O          O     O  O   O   O',
'O   OOO  O          O        O   O   O',
'O        O          O        O       O',
'OOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO',
]
map_width  = max([len(x) for x in map_data])
map_height = len(map_data)

class Node(object):
    """
    f(n) ... startからgoalまでの最短距離
    g(n) ... startからnノードまでの最短距離
    h(n) ... nノードからgoalまでの最短距離
    f(n) = g(n) + h(n)

    関数を推定値にすることにより最短距離を予測する
    h*(n)をnからgoalまでの直線距離と仮定する。

    f*(n) = g*(n) + h*(n)
    """
    start = None    #start位置(x,y)
    goal = None     #goal位置(x,y)

    def __init__(self,x,y):
        self.pos    = (x,y)
        self.hs     = (x-self.goal[0])**2 + (y-self.goal[1])**2
        self.fs     = 0
        self.owner_list  = None
        self.parent_node = None

    def isGoal(self):
        return self.goal == self.pos

class NodeList(list):
    def find(self, x,y):
        l = [t for t in self if t.pos==(x,y)]
        return l[0] if l != [] else None
    def remove(self,node):
        del self[self.index(node)]


#スタート位置とゴール位置を設定
for (i,x) in enumerate(map_data):
    if 'S' in x:
        Node.start = (x.index('S'),i)
    elif 'G' in x:
        Node.goal = (x.index('G'),i)

#OpenリストとCloseリストを設定
open_list     = NodeList()
close_list    = NodeList()
start_node    = Node(*Node.start)
start_node.fs = start_node.hs
open_list.append(start_node)

while True:
    #Openリストが空になったら解なし
    if open_list == []:
        print "There is no route until reaching a goal."
        exit(1);

    #Openリストからf*が最少のノードnを取得
    n = min(open_list, key=lambda x:x.fs)
    open_list.remove(n)
    close_list.append(n)

    #最小ノードがゴールだったら終了
    if n.isGoal():
        end_node = n
        break

    #f*() = g*() + h*() -> g*() = f*() - h*()
    n_gs = n.fs - n.hs

    #ノードnの移動可能方向のノードを調べる
    for v in ((1,0),(-1,0),(0,1),(0,-1)):
        x = n.pos[0] + v[0]
        y = n.pos[1] + v[1]

        #マップが範囲外または壁(O)の場合はcontinue
        if not (0 < y < map_height and
                0 < x < map_width and
                map_data[y][x] != 'O'):
            continue

        #移動先のノードがOpen,Closeのどちらのリストに
        #格納されているか、または新規ノードなのかを調べる
        m = open_list.find(x,y)
        dist = (n.pos[0]-x)**2 + (n.pos[1]-y)**2
        if m:
            #移動先のノードがOpenリストに格納されていた場合、
            #より小さいf*ならばノードmのf*を更新し、親を書き換え
            if m.fs > n_gs + m.hs + dist:
                m.fs = n_gs + m.hs + dist
                m.parent_node = n
        else:
            m = close_list.find(x,y)
            if m:
                #移動先のノードがCloseリストに格納されていた場合、
                #より小さいf*ならばノードmのf*を更新し、親を書き換え
                #かつ、Openリストに移動する
                if m.fs > n_gs + m.hs + dist:
                    m.fs = n_gs + m.hs + dist
                    m.parent_node = n
                    open_list.append(m)
                    close_list.remove(m)
            else:
                #新規ノードならばOpenリストにノードに追加
                m = Node(x,y)
                m.fs = n_gs + m.hs + dist
                m.parent_node = n
                open_list.append(m)


#endノードから親を辿っていくと、最短ルートを示す
n = end_node.parent_node
m = [[x for x in line] for line in map_data]

while True:
    if n.parent_node == None:
        break
    m[n.pos[1]][n.pos[0]] = '+'
    n = n.parent_node
print "\n".join(["".join(x) for x in m])

もうちょっと簡潔に書けるかな・・・?
(2009.7.15 NodeListクラスを作って、少しだけ簡潔にしました)
実行結果は以下の通りです、「S」がスタート位置、「G」がゴール位置、「O」が壁となっています。

OOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO
OS+ O     O     O         O   +++++++O
O + O  O  O  O  O  +++++++O   +OOOO GO
O +    O     O  O  +OOOO +O   +O  OOOO
OO+OOOOOOOOOOOOOOO +O    +O   +O     O
O ++++++++++++   O +O    +O   ++++++ O
O        OOO +   O +O    +OOOOOOOOO+ O
O  OO    O   +OOOO +O    +O +++++OO+ O
O   O    O   +++++++O    +O +O  +O++ O
O   OOO  O          O    ++++O  +O+  O
O        O          O        O  +++  O
OOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO

結構難しい迷路ですけど、ちゃんとやってくれますね。
移動方向は上下左右の4方向のみですが、以下のコードを変更すると斜めを入れた8方向で探索してくれます。
(行が長ぇよ!という方は、移動方向を変数にて定義してください)

#ノードnの移動可能方向のノードを調べる
for v in ((1,0),(-1,0),(0,1),(0,-1)):
    ↓
#ノードnの移動可能方向のノードを調べる
for v in ((1,0),(-1,0),(0,1),(0,-1),(1,1),(-1,1),(1,-1),(-1,-1)):

実行結果は以下の通りです、斜めに進んでくれるので、前よりも効率的に進んでいるのがわかります。

OOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO
OS  O     O     O         O    +++++ O
O + O  O  O  O  O   +++++ O   +OOOO GO
O +    O     O  O  +OOOO +O   +O  OOOO
OO+OOOOOOOOOOOOOOO +O    +O   +O     O
O  ++++++++++    O +O    +O    ++++  O
O        OOO +   O +O    +OOOOOOOOO+ O
O  OO    O   +OOOO+ O    +O ++++ OO+ O
O   O    O    ++++  O    +O+ O  +O + O
O   OOO  O          O     +  O  +O+  O
O        O          O        O   +   O
OOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO

実際にゲームに組み込む時は、もうちょっと工夫が必要かもしれませんね。