summaryrefslogtreecommitdiff
path: root/cv/holiday_similarity/main.py
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2020-08-08 20:21:47 +0800
committerzhang <zch921005@126.com>2020-08-08 20:21:47 +0800
commit2816f0ecda446dbd902bfab4a13d7bc95b0a5d33 (patch)
tree0aac7bda9692de91327231fa58a4540126548d3d /cv/holiday_similarity/main.py
parent8ebc34e31433d73d630d1431acd80ce2e922395b (diff)
holiday similarity update
Diffstat (limited to 'cv/holiday_similarity/main.py')
-rw-r--r--cv/holiday_similarity/main.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/cv/holiday_similarity/main.py b/cv/holiday_similarity/main.py
new file mode 100644
index 0000000..a292aab
--- /dev/null
+++ b/cv/holiday_similarity/main.py
@@ -0,0 +1,28 @@
+from cv.holiday_similarity import vis_utils
+from cv.holiday_similarity.data_utils import *
+from cv.holiday_similarity.prepare_models import e2e_pretrained_network
+
+BATCH_SIZE = 64
+
+triples_data = create_triples(IMAGE_DIR)
+split_point = int(len(triples_data) * 0.7)
+triples_train, triples_test = triples_data[0:split_point], triples_data[split_point:]
+
+NUM_EPOCHS = 10
+
+image_cache = {}
+train_gen = generate_image_triples_batch(triples_train, BATCH_SIZE, shuffle=True)
+val_gen = generate_image_triples_batch(triples_test, BATCH_SIZE, shuffle=False)
+
+num_train_steps = len(triples_train) // BATCH_SIZE
+num_val_steps = len(triples_test) // BATCH_SIZE
+
+# nn = e2e_network()
+nn = e2e_pretrained_network()
+history = nn.fit_generator(train_gen,
+ steps_per_epoch=num_train_steps,
+ epochs=NUM_EPOCHS,
+ validation_data=val_gen,
+ validation_steps=num_val_steps)
+
+vis_utils.plot_training_curve(history)