Postsndarrayのmutable/imutableで混乱した話
概要
初歩的ですが忘れないようにメモに残します。
コード
import numpy as np
def function1(x):
x = x + 5
def function2(x):
x += 5
def main():
# case1
x = np.zeros(3)
print(f"before: {x}")
function1(x)
print(f"after: {x}")
# case2
x = np.zeros(3)
print(f"before: {x}")
function2(x)
print(f"after: {x}")
if __name__ == "__main__":
main()
before: [0. 0. 0.]
after: [0. 0. 0.]
before: [0. 0. 0.]
after: [5. 5. 5.]
case1とcase2でprintの出力が違います。
case1ではfunction1
内で処理された変数x
がmain
関数内には残らないのに対し,case2ではfunction2
内の処理が残り,x
がglobal変数のような振る舞いをしています。
原因
x = np.zeros(3)
で最初に変数x
を定義するとき,変数x
はmutableなオブジェクトとして生成されます。
function1
のx += 5
ではこのx
を直接的に変更することになり変数のidが維持されますが,function2
のx = x + 5
はx
を再定義することになります。
つまり変数x
は,main
とfunction1
では同一のもの,main
とfunction2
では別物として取り扱われます。
以下のように変数x
内の要素への代入でもfunction2
と同様の挙動をします。
# case3
def function3(x):
for i, xx in enumerate(x):
x[i] = xx + i**2
x = np.zeros(3)
print(f"before: {x}")
function3(x)
print(f"after: {x}")
before: [0. 0. 0.]
after: [0. 1. 4.]
まとめ
関数内の変数はreturnしない限り使い捨てだと思っていたので,混乱しました…
意図せず変数がすり替わってしまうこともあるので,基本的には引数のmutableを前提にした処理は書かないほうが無難な気がします。