如何获取 tf.estimator.Estimator.predict 的预测结果?网上没找到答案,请教!
def model_fn(features, labels, mode, params):
    """Estimator model function (most of the body is elided in the original post).

    The visible tail wraps the tensor named ``predictions`` (computed in the
    elided section) into a dict under the key ``'probability'``. It registers
    the same dict as a ``PredictOutput`` export signature and returns an
    ``EstimatorSpec`` carrying both.

    NOTE(review): the returned spec sets only ``predictions`` and
    ``export_outputs``. TRAIN and EVAL modes also need ``loss``, and TRAIN
    also needs a ``train_op``. Presumably the elided code returns earlier for
    those modes; confirm.
    """
    ......
    # `predictions` is rebound here. The right-hand side still reads the tensor
    # produced by the elided code; afterwards the name refers to the output dict.
    predictions={
        # NOTE(review): reshaping to [1, -1] flattens everything into a single
        # row, so the leading (batch) dimension becomes 1. By default,
        # Estimator.predict splits results along the first dimension. Each
        # input batch would therefore come back as ONE dict whose 'probability'
        # holds the whole batch. If one result per example is wanted, keep the
        # batch dimension (e.g. reshape to [-1]). This assumes `predictions`
        # holds one value per example; verify its shape.
        'probability': tf.reshape(predictions, [1, -1])
    }
    # Export signature(s) used when the model is exported for serving; the
    # dict key is the signature name.
    export_outputs={
        'predictions': tf.estimator.export.PredictOutput(predictions)
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions,
                                      export_outputs=export_outputs)
def main(_):
    """Train and evaluate the Estimator, then run prediction and read the results.

    The training/evaluation part is unchanged from the original post. The
    prediction part is fixed in two ways:

    * ``Estimator.predict`` takes an ``input_fn``, not an ``EvalSpec``.
    * It returns a lazy generator. Nothing is computed until the generator is
      iterated. Each yielded item is a dict keyed like the ``predictions``
      dict built in ``model_fn`` (here ``'probability'``).
    """
    model = tf.estimator.Estimator(
        model_fn=model_fn,
        # ... remaining Estimator arguments elided in the original post ...
    )
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: input_fn(mode='train', pattern='D:/din/dataset/*'))
    eval_spec = tf.estimator.EvalSpec(
        input_fn=lambda: input_fn(mode='eval', pattern='D:/din/dataset/*'),
        steps=100, throttle_secs=60)
    tf.estimator.train_and_evaluate(model, train_spec, eval_spec)

    test_input_fn = lambda: input_fn(mode='predict', pattern='D:/din/dataset/*')
    # Pass the input function itself. Wrapping it in an EvalSpec (as the
    # original did) is not a valid `input_fn`, and the mistake stays hidden
    # until the generator is iterated.
    predis = model.predict(input_fn=test_input_fn)
    # Iterate the generator to obtain the actual values; printing `predis`
    # itself only shows "<generator object Estimator.predict ...>".
    # Iteration ends when test_input_fn runs out of data, so the predict-mode
    # dataset must be finite (no endless .repeat()). Presumably input_fn
    # handles this for mode='predict'; verify.
    for pred in predis:
        print(pred['probability'])
0 个回答
暂无回答