雰囲気データサイエンティストの備忘録

Atmosphere Data Scientist's Memorandum


ndarrayのmutable/imutableで混乱した話
Python

概要

初歩的ですが忘れないようにメモに残します。

コード

import numpy as np

def function1(x):
    x = x + 5

def function2(x):
    x += 5

def main():
    # case1
    x = np.zeros(3)
    print(f"before: {x}")
    function1(x)
    print(f"after: {x}")

    # case2
    x = np.zeros(3)
    print(f"before: {x}")
    function2(x)
    print(f"after: {x}")

if __name__ == "__main__":
    main()
before: [0. 0. 0.]
after: [0. 0. 0.]

before: [0. 0. 0.]
after: [5. 5. 5.]

case1とcase2でprintの出力が違います。

case1ではfunction1内で処理された変数xmain関数内には残らないのに対し,case2ではfunction2内の処理が残り,xがglobal変数のような振る舞いをしています。

原因

x = np.zeros(3)で最初に変数xを定義するとき,変数xはmutableなオブジェクトとして生成されます。
function1x += 5ではこのxを直接的に変更することになり変数のidが維持されますが,function2x = x + 5xを再定義することになります。

つまり変数xは,mainfunction1では同一のもの,mainfunction2では別物として取り扱われます。

以下のように変数x内の要素への代入でもfunction2と同様の挙動をします。

# case3
def function3(x):
    for i, xx in enumerate(x):
        x[i] = xx + i**2

x = np.zeros(3)
print(f"before: {x}")
function3(x)
print(f"after: {x}")
before: [0. 0. 0.]
after: [0. 1. 4.]

まとめ

関数内の変数はreturnしない限り使い捨てだと思っていたので,混乱しました…
意図せず変数がすり替わってしまうこともあるので,基本的には引数のmutableを前提にした処理は書かないほうが無難な気がします。