summaryrefslogtreecommitdiff
path: root/cv/holiday_similarity/eval_utils.py
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2020-08-08 20:21:47 +0800
committerzhang <zch921005@126.com>2020-08-08 20:21:47 +0800
commit2816f0ecda446dbd902bfab4a13d7bc95b0a5d33 (patch)
tree0aac7bda9692de91327231fa58a4540126548d3d /cv/holiday_similarity/eval_utils.py
parent8ebc34e31433d73d630d1431acd80ce2e922395b (diff)
holiday similarity update
Diffstat (limited to 'cv/holiday_similarity/eval_utils.py')
-rw-r--r--cv/holiday_similarity/eval_utils.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/cv/holiday_similarity/eval_utils.py b/cv/holiday_similarity/eval_utils.py
new file mode 100644
index 0000000..b83125e
--- /dev/null
+++ b/cv/holiday_similarity/eval_utils.py
@@ -0,0 +1,24 @@
+import numpy as np
+from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
+
+
+def evaluate_model(model, test_gen, test_triples, batch_size):
+ # model_name = os.path.basename(model_file)
+ # model = load_model(model_file)
+ # print("=== Evaluating model: {:s} ===".format(model_name))
+ print("=== Evaluating model")
+ ytrue, ypred = [], []
+ num_test_steps = len(test_triples) // batch_size
+ for i in range(num_test_steps):
+ # (X1, X2), Y = test_gen.next()
+ (X1, X2), Y = next(test_gen)
+ Y_ = model.predict([X1, X2])
+ ytrue.extend(np.argmax(Y, axis=1).tolist())
+ ypred.extend(np.argmax(Y_, axis=1).tolist())
+ accuracy = accuracy_score(ytrue, ypred)
+ print("\nAccuracy: {:.3f}".format(accuracy))
+ print("\nConfusion Matrix")
+ print(confusion_matrix(ytrue, ypred))
+ print("\nClassification Report")
+ print(classification_report(ytrue, ypred))
+ return accuracy