您的位置:首页 > 房产 > 建筑 > Keras深度学习框架第三十讲:在KerasTuner中处理失败的训练

Keras深度学习框架第三十讲:在KerasTuner中处理失败的训练

2025/1/23 13:43:46 来源:https://blog.csdn.net/MUKAMO/article/details/139247201  浏览:    关键词:Keras深度学习框架第三十讲:在KerasTuner中处理失败的训练

1、绪论

KerasTuner程序可能会运行很长时间,因为每个模型的训练可能需要很长时间。我们不希望程序仅仅因为一些试验随机失败就停止运行。

本文将讨论在KerasTuner中如何处理失败的训练,包括:

  • 如何在搜索过程中容忍失败的训练
  • 如何在构建和评估模型时将某个训练标记为失败
  • 如何通过抛出FatalError来终止搜索过程

正式讨论使用前,需要进行如下的设置

!pip install keras-tuner -q
import keras
from keras import layers
import keras_tuner
import numpy as np

2、容忍失败的训练

在初始化tuner时,我们将使用max_retries_per_trialmax_consecutive_failed_trials参数。

max_retries_per_trial控制如果一个试验持续失败时,允许重试的最大次数。例如,如果它被设置为3,那么该试验可能会运行4次(1次失败的运行 + 3次失败的重试),之后才最终被标记为失败。max_retries_per_trial的默认值是0。

max_consecutive_failed_trials控制在终止搜索之前,允许连续失败多少个训练(这里的失败试验指的是一个训练在所有的重试中都失败了)。例如,如果它被设置为3,并且试验2、试验3和试验4都失败了,搜索就会终止。但是,如果它被设置为3,但只有试验2、试验3、试验5和试验6失败,搜索不会终止,因为失败的试验不是连续的。max_consecutive_failed_trials的默认值是3。

以下代码展示了这两个参数如何工作。

我们定义了一个搜索空间,包含两个超参数,用于确定两个密集层的单元数。当这两个数的乘积大于800时,我们会为模型太大而引发ValueError

def build_model(hp):# Define the 2 hyperparameters for the units in dense layersunits_1 = hp.Int("units_1", 10, 40, step=10)units_2 = hp.Int("units_2", 10, 30, step=10)# Define the modelmodel = keras.Sequential([layers.Dense(units=units_1, input_shape=(20,)),layers.Dense(units=units_2),layers.Dense(units=1),])model.compile(loss="mse")# Raise an error when the model is too largenum_params = model.count_params()if num_params > 1200:raise ValueError(f"Model too large! It contains {num_params} params.")return model

我们按照以下方式设置tuner。

我们将max_retries_per_trial设置为3。
我们将max_consecutive_failed_trials设置为8。
我们使用GridSearch来枚举所有超参数值的组合。

这样设置后,如果在某个超参数组合下模型训练失败,KerasTuner会为该试验重试最多3次。如果某个试验在所有的重试中都失败了,并且连续有8个这样的失败试验,那么搜索过程将被终止。而GridSearch则会遍历搜索空间中的每个超参数组合,无论它们是否导致试验失败。

tuner = keras_tuner.GridSearch(hypermodel=build_model,objective="val_loss",overwrite=True,max_retries_per_trial=3,max_consecutive_failed_trials=8,
)# Use random data to train the model.
tuner.search(x=np.random.rand(100, 20),y=np.random.rand(100, 1),validation_data=(np.random.rand(100, 20),np.random.rand(100, 1),),epochs=10,
)# Print the results.
tuner.results_summary()

3、将试验标记为失败

当模型过大时,我们不需要重新尝试它。无论使用相同的超参数尝试多少次,模型都会过大。

我们可以将max_retries_per_trial设置为0来实现这一点。但是,这样做的话,无论出现什么错误都不会重试,而我们可能仍然希望对于其他意外的错误进行重试。有没有更好的方式来处理这种情况?

我们可以引发FailedTrialError来跳过重试。每当引发此错误时,试验将不会被重试。当发生其他错误时,重试仍会进行。以下是一个示例:

def build_model(hp):# Define the 2 hyperparameters for the units in dense layersunits_1 = hp.Int("units_1", 10, 40, step=10)units_2 = hp.Int("units_2", 10, 30, step=10)# Define the modelmodel = keras.Sequential([layers.Dense(units=units_1, input_shape=(20,)),layers.Dense(units=units_2),layers.Dense(units=1),])model.compile(loss="mse")# Raise an error when the model is too largenum_params = model.count_params()if num_params > 1200:# When this error is raised, it skips the retries.raise keras_tuner.errors.FailedTrialError(f"Model too large! It contains {num_params} params.")return modeltuner = keras_tuner.GridSearch(hypermodel=build_model,objective="val_loss",overwrite=True,max_retries_per_trial=3,max_consecutive_failed_trials=8,
)# Use random data to train the model.
tuner.search(x=np.random.rand(100, 20),y=np.random.rand(100, 1),validation_data=(np.random.rand(100, 20),np.random.rand(100, 1),),epochs=10,
)# Print the results.
tuner.results_summary()

4、程序化终止搜索

当代码中存在错误时,我们应该立即终止搜索并修复该错误。当满足您定义的条件时,程序员可以程序化地终止搜索。引发FatalError(或其子类FatalValueError、FatalTypeError或FatalRuntimeError)将终止搜索,而不管max_consecutive_failed_trials参数的值如何。

以下是一个当模型过大时终止搜索的示例:

def build_model(hp):# Define the 2 hyperparameters for the units in dense layersunits_1 = hp.Int("units_1", 10, 40, step=10)units_2 = hp.Int("units_2", 10, 30, step=10)# Define the modelmodel = keras.Sequential([layers.Dense(units=units_1, input_shape=(20,)),layers.Dense(units=units_2),layers.Dense(units=1),])model.compile(loss="mse")# Raise an error when the model is too largenum_params = model.count_params()if num_params > 1200:# When this error is raised, the search is terminated.raise keras_tuner.errors.FatalError(f"Model too large! It contains {num_params} params.")return modeltuner = keras_tuner.GridSearch(hypermodel=build_model,objective="val_loss",overwrite=True,max_retries_per_trial=3,max_consecutive_failed_trials=8,
)try:# Use random data to train the model.tuner.search(x=np.random.rand(100, 20),y=np.random.rand(100, 1),validation_data=(np.random.rand(100, 20),np.random.rand(100, 1),),epochs=10,)
except keras_tuner.errors.FatalError:print("The search is terminated.")

5、总结

本文讨论学习了如何在KerasTuner中处理失败的试验:

在KerasTuner中处理失败的试验有几种不同的方法,具体取决于失败的原因和你想要的行为。以下是几种常见的策略:

5.1 设置重试次数 (max_retries_per_trial)

如果你知道某些类型的错误可能是暂时的(例如,由于网络中断或硬件问题),你可以通过设置max_retries_per_trial参数来指定在标记一个试验为失败之前应该重试多少次。这样,即使试验在首次尝试时失败,Tuner也会尝试使用相同的超参数重新运行试验。

tuner = RandomSearch(build_model,objective='val_loss',max_trials=10,executions_per_trial=1,max_retries_per_trial=3,  # 最多重试3次# ... 其他参数 ...
)

5.2 直接标记试验为失败 (FailedTrialError)

如果你能在构建模型或运行试验的过程中检测到某些条件将导致失败(例如,模型大小超出限制),你可以引发FailedTrialError来直接标记该试验为失败。这将导致Tuner跳过该试验的重试,并继续尝试其他超参数组合。

from keras_tuner.engine.trial import FailedTrialErrordef build_model(hp):# ... 省略模型构建代码 ...if some_condition_that_will_fail:raise FailedTrialError('Model configuration will fail due to ...')# ... 省略模型其余部分的构建代码 ...

5.3 立即终止搜索 (FatalError 及其子类)

如果你遇到了一个严重的错误,该错误表明搜索无法继续进行(例如,数据加载错误或模型构建中的根本性错误),你可以引发FatalError或其子类(如FatalValueErrorFatalTypeErrorFatalRuntimeError)来立即终止搜索。这将停止Tuner的所有活动,并允许你修复代码中的错误。

from keras_tuner.engine.trial import FatalErrordef build_model(hp):# ... 省略模型构建代码 ...if some_unrecoverable_error:raise FatalError('Unrecoverable error occurred: ...')# ... 省略模型其余部分的构建代码 ...

5.4 处理异常并继续

有时,你可能想要捕获异常并进行一些处理(例如,记录错误或尝试使用不同的策略),而不是直接标记试验为失败或终止搜索。你可以使用Python的异常处理机制(try/except块)来实现这一点。

def build_model(hp):try:# ... 省略模型构建代码 ...# 这里可能会引发异常except SomeSpecificError as e:# 处理异常,例如记录日志或采取其他措施print(f'Caught an exception: {e}')# 但不引发FailedTrialError或FatalError,以便Tuner可以继续# ... 省略模型其余部分的构建代码 ...

5.5 自定义回调

你还可以使用KerasTuner的回调机制来在试验的不同阶段执行自定义逻辑。例如,你可以在试验开始前、结束后或每个epoch结束时运行自定义函数,以检查试验的状态或执行其他操作。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com