tf2 自定義loss加載報錯
作者:互联网
問題描述
ValueError: Unknown loss function: bes_loss
問題場景
- 訓練
margin = 0.6
theta = lambda t : (K.sign(t) + 1.) / 2
def bes_loss(y_true, y_pred):
return - (1 - theta(y_true - margin) * theta(y_pred - margin)
- theta(1 - margin - y_true) * theta(1 - margin - y_pred)
) * (y_true * K.log(y_pred + 1e-8) + (1 - y_true) * K.log(1 - y_pred + 1e-8))
···
model.compile(tf.optimizers.Adam(), loss=bes_loss,metrics=['accuracy'])
- 測試
model = load_model(config.model_path, custom_objects={'bes_loss':bes_loss})
這樣的加載方式就會出現報錯,如問題描述
問題解決
model = load_model(config.model_path, custom_objects={'bes_loss':bes_loss}, compile = False)
model.compile(tf.optimizers.Adam(), loss=bes_loss, metrics=['accuracy'])
通過compile=False忽略加載錯誤報錯,然後再通過model.compile()加載模型的配置
标签:loss,pred,加載,theta,bes,model,true,報錯 来源: https://www.cnblogs.com/monkeyT/p/16107119.html