はじめに
ジェネレータは値を一度に全て生成せず、必要な時に一つずつ生成する仕組みです。メモリ効率が良く、大量のデータを扱う際に有効です。
イテレータとは
イテレータは__iter__と__next__を持つオブジェクトです。
# リストからイテレータを作成
numbers = [1, 2, 3]
iterator = iter(numbers)
print(next(iterator)) # 1
print(next(iterator)) # 2
print(next(iterator)) # 3
# print(next(iterator)) # StopIteration例外
# for文は内部でイテレータを使用
for num in numbers: # iter(numbers)でイテレータを取得
print(num) # next()で値を取得
ジェネレータ関数
yieldを使う関数はジェネレータ関数になります。
def count_up(n):
"""0からn-1までカウント"""
i = 0
while i < n:
yield i
i += 1
# ジェネレータオブジェクトを取得
gen = count_up(5)
print(type(gen)) # <class 'generator'>
# 値を取得
print(next(gen)) # 0
print(next(gen)) # 1
print(next(gen)) # 2
# for文で使う
for num in count_up(5):
print(num) # 0, 1, 2, 3, 4
yieldの動作
def simple_generator():
print("開始")
yield 1
print("1の後")
yield 2
print("2の後")
yield 3
print("終了")
gen = simple_generator()
print(next(gen)) # "開始" と 1 を出力
print(next(gen)) # "1の後" と 2 を出力
print(next(gen)) # "2の後" と 3 を出力
# next(gen) # "終了" と StopIteration
メモリ効率
import sys
# リスト - 全要素をメモリに保持
def get_squares_list(n):
return [x ** 2 for x in range(n)]
# ジェネレータ - 必要な時に生成
def get_squares_gen(n):
for x in range(n):
yield x ** 2
# メモリ使用量の比較
n = 1000000
lst = get_squares_list(n)
gen = get_squares_gen(n)
print(f"リスト: {sys.getsizeof(lst):,} bytes") # 約8MB
print(f"ジェネレータ: {sys.getsizeof(gen)} bytes") # 約200バイト
実践例
ファイルを1行ずつ読む
def read_lines(filepath):
"""ファイルを1行ずつ読むジェネレータ"""
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
yield line.strip()
# 大きなファイルでもメモリを圧迫しない
for line in read_lines("large_file.txt"):
process(line)
無限シーケンス
def infinite_counter(start=0):
"""無限カウンター"""
n = start
while True:
yield n
n += 1
# 必要な分だけ取得
counter = infinite_counter()
for _ in range(5):
print(next(counter)) # 0, 1, 2, 3, 4
# itertoolsと組み合わせ
from itertools import islice
first_10 = list(islice(infinite_counter(100), 10))
print(first_10) # [100, 101, ..., 109]
フィボナッチ数列
def fibonacci():
"""フィボナッチ数列を生成"""
a, b = 0, 1
while True:
yield a
a, b = b, a + b
# 最初の10項
from itertools import islice
print(list(islice(fibonacci(), 10)))
# [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
yield from
ネストしたジェネレータを簡潔に書けます。
def nested_generator():
yield from [1, 2, 3]
yield from [4, 5, 6]
# これと同じ
def nested_generator_old():
for x in [1, 2, 3]:
yield x
for x in [4, 5, 6]:
yield x
print(list(nested_generator())) # [1, 2, 3, 4, 5, 6]
# 再帰的なジェネレータ
def flatten(nested):
"""ネストしたリストをフラット化"""
for item in nested:
if isinstance(item, list):
yield from flatten(item)
else:
yield item
nested = [1, [2, 3, [4, 5]], 6, [7]]
print(list(flatten(nested))) # [1, 2, 3, 4, 5, 6, 7]
send()でジェネレータに値を送る
def accumulator():
"""値を受け取って累積"""
total = 0
while True:
value = yield total
if value is not None:
total += value
acc = accumulator()
next(acc) # ジェネレータを開始
print(acc.send(10)) # 10
print(acc.send(20)) # 30
print(acc.send(5)) # 35
ジェネレータの終了
def controlled_generator():
try:
while True:
yield "running"
except GeneratorExit:
print("ジェネレータが閉じられました")
finally:
print("クリーンアップ")
gen = controlled_generator()
print(next(gen)) # "running"
gen.close() # GeneratorExitを発生させる
# 出力: "ジェネレータが閉じられました"
# "クリーンアップ"
itertoolsモジュール
from itertools import count, cycle, repeat, chain, islice, takewhile, dropwhile
# 無限イテレータ
for i in islice(count(10, 2), 5): # 10から2ずつ、5個
print(i) # 10, 12, 14, 16, 18
# サイクル
colors = cycle(["red", "green", "blue"])
for _ in range(5):
print(next(colors)) # red, green, blue, red, green
# チェーン(連結)
combined = chain([1, 2], [3, 4], [5, 6])
print(list(combined)) # [1, 2, 3, 4, 5, 6]
# 条件で取得/スキップ
numbers = [1, 3, 5, 7, 2, 4, 6]
print(list(takewhile(lambda x: x < 5, numbers))) # [1, 3]
print(list(dropwhile(lambda x: x < 5, numbers))) # [5, 7, 2, 4, 6]
カスタムイテレータ
class Countdown:
"""カウントダウンイテレータ"""
def __init__(self, start):
self.start = start
def __iter__(self):
return self
def __next__(self):
if self.start <= 0:
raise StopIteration
self.start -= 1
return self.start + 1
# 使用
for num in Countdown(5):
print(num) # 5, 4, 3, 2, 1
パイプライン処理
def read_data(filepath):
"""データを読み込む"""
with open(filepath) as f:
for line in f:
yield line.strip()
def parse_json(lines):
"""JSONをパース"""
import json
for line in lines:
if line:
yield json.loads(line)
def filter_active(records):
"""アクティブなレコードをフィルタ"""
for record in records:
if record.get("active"):
yield record
def transform(records):
"""データを変換"""
for record in records:
yield {
"id": record["id"],
"name": record["name"].upper()
}
# パイプラインを構築(遅延評価)
pipeline = transform(
filter_active(
parse_json(
read_data("data.jsonl")
)
)
)
# 必要な時に処理される
for item in pipeline:
print(item)
まとめ
yieldを使う関数はジェネレータ関数- ジェネレータは遅延評価でメモリ効率が良い
- 大量データや無限シーケンスに適している
yield fromでネストしたジェネレータを簡潔にitertoolsで便利なイテレータツールを利用- パイプライン処理で複雑なデータ処理を構築
次回はデコレータについて学びます。