fit_generator()

2018-04-03  本文已影响0人  Manfestain

在数据处理和网络定义完成后,跑模型时突然出现了错误: OOM

刚开始也不知道哪里的问题,发现有可能是内存耗尽了,然后就放进去500张图片进行fit,然后问题就消失了,猜想应该是数据太大,内存开销不够。
发现官方文档中说可以使用fit_generator()分批训练。


官方文档如下:

fit_generator(self, generator, 
                    steps_per_epoch=None, 
                    epochs=1, 
                    verbose=1, 
                    callbacks=None, 
                    validation_data=None, 
                    validation_steps=None,  
                    class_weight=None,
                    max_queue_size=10,   
                    workers=1, 
                    use_multiprocessing=False, 
                    shuffle=True, 
                    initial_epoch=0)

通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。

参数:

例子:
datagen = ImageDataGenator(...)
model.fit_generator(datagen.flow(x_train, y_train,
                                 batch_size=batch_size),
                    epochs=epochs,
                    validation_data=(x_test, y_test),
                    workers=4)
上一篇 下一篇

猜你喜欢

热点阅读