From 2816f0ecda446dbd902bfab4a13d7bc95b0a5d33 Mon Sep 17 00:00:00 2001 From: zhang Date: Sat, 8 Aug 2020 20:21:47 +0800 Subject: holiday similarity update --- cv/holiday_similarity/eval_utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 cv/holiday_similarity/eval_utils.py (limited to 'cv/holiday_similarity/eval_utils.py') diff --git a/cv/holiday_similarity/eval_utils.py b/cv/holiday_similarity/eval_utils.py new file mode 100644 index 0000000..b83125e --- /dev/null +++ b/cv/holiday_similarity/eval_utils.py @@ -0,0 +1,24 @@ +import numpy as np +from sklearn.metrics import accuracy_score, confusion_matrix, classification_report + + +def evaluate_model(model, test_gen, test_triples, batch_size): + # model_name = os.path.basename(model_file) + # model = load_model(model_file) + # print("=== Evaluating model: {:s} ===".format(model_name)) + print("=== Evaluating model") + ytrue, ypred = [], [] + num_test_steps = len(test_triples) // batch_size + for i in range(num_test_steps): + # (X1, X2), Y = test_gen.next() + (X1, X2), Y = next(test_gen) + Y_ = model.predict([X1, X2]) + ytrue.extend(np.argmax(Y, axis=1).tolist()) + ypred.extend(np.argmax(Y_, axis=1).tolist()) + accuracy = accuracy_score(ytrue, ypred) + print("\nAccuracy: {:.3f}".format(accuracy)) + print("\nConfusion Matrix") + print(confusion_matrix(ytrue, ypred)) + print("\nClassification Report") + print(classification_report(ytrue, ypred)) + return accuracy -- cgit v1.2.3