其他分享
首页 > 其他分享> > tf2 自定義loss加載報錯

tf2 自定義loss加載報錯

作者:互联网

問題描述

ValueError: Unknown loss function: bes_loss

問題場景

margin = 0.6
theta = lambda t : (K.sign(t) + 1.) / 2
def bes_loss(y_true, y_pred):
    return - (1 - theta(y_true - margin) * theta(y_pred - margin)
            - theta(1 - margin - y_true) * theta(1 - margin - y_pred)
         ) * (y_true * K.log(y_pred + 1e-8) + (1 - y_true) * K.log(1 - y_pred + 1e-8))
···
model.compile(tf.optimizers.Adam(), loss=bes_loss,metrics=['accuracy'])
model = load_model(config.model_path, custom_objects={'bes_loss':bes_loss})

這樣的加載方式就會出現報錯,如問題描述

問題解決

model = load_model(config.model_path, custom_objects={'bes_loss':bes_loss}, compile = False)
model.compile(tf.optimizers.Adam(), loss=bes_loss, metrics=['accuracy'])

通過compile=False忽略加載錯誤報錯,然後再通過model.compile()加載模型的配置

标签:loss,pred,加載,theta,bes,model,true,報錯
来源: https://www.cnblogs.com/monkeyT/p/16107119.html