From 2816f0ecda446dbd902bfab4a13d7bc95b0a5d33 Mon Sep 17 00:00:00 2001 From: zhang Date: Sat, 8 Aug 2020 20:21:47 +0800 Subject: holiday similarity update --- cv/holiday_similarity/main.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 cv/holiday_similarity/main.py (limited to 'cv/holiday_similarity/main.py') diff --git a/cv/holiday_similarity/main.py b/cv/holiday_similarity/main.py new file mode 100644 index 0000000..a292aab --- /dev/null +++ b/cv/holiday_similarity/main.py @@ -0,0 +1,28 @@ +from cv.holiday_similarity import vis_utils +from cv.holiday_similarity.data_utils import * +from cv.holiday_similarity.prepare_models import e2e_pretrained_network + +BATCH_SIZE = 64 + +triples_data = create_triples(IMAGE_DIR) +split_point = int(len(triples_data) * 0.7) +triples_train, triples_test = triples_data[0:split_point], triples_data[split_point:] + +NUM_EPOCHS = 10 + +image_cache = {} +train_gen = generate_image_triples_batch(triples_train, BATCH_SIZE, shuffle=True) +val_gen = generate_image_triples_batch(triples_test, BATCH_SIZE, shuffle=False) + +num_train_steps = len(triples_train) // BATCH_SIZE +num_val_steps = len(triples_test) // BATCH_SIZE + +# nn = e2e_network() +nn = e2e_pretrained_network() +history = nn.fit_generator(train_gen, + steps_per_epoch=num_train_steps, + epochs=NUM_EPOCHS, + validation_data=val_gen, + validation_steps=num_val_steps) + +vis_utils.plot_training_curve(history) -- cgit v1.2.3