diff options
| author | zhang <zch921005@126.com> | 2020-08-08 20:21:47 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2020-08-08 20:21:47 +0800 |
| commit | 2816f0ecda446dbd902bfab4a13d7bc95b0a5d33 (patch) | |
| tree | 0aac7bda9692de91327231fa58a4540126548d3d /cv/holiday_similarity/main.py | |
| parent | 8ebc34e31433d73d630d1431acd80ce2e922395b (diff) | |
holiday similarity update
Diffstat (limited to 'cv/holiday_similarity/main.py')
| -rw-r--r-- | cv/holiday_similarity/main.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/cv/holiday_similarity/main.py b/cv/holiday_similarity/main.py new file mode 100644 index 0000000..a292aab --- /dev/null +++ b/cv/holiday_similarity/main.py @@ -0,0 +1,28 @@ +from cv.holiday_similarity import vis_utils +from cv.holiday_similarity.data_utils import * +from cv.holiday_similarity.prepare_models import e2e_pretrained_network + +BATCH_SIZE = 64 + +triples_data = create_triples(IMAGE_DIR) +split_point = int(len(triples_data) * 0.7) +triples_train, triples_test = triples_data[0:split_point], triples_data[split_point:] + +NUM_EPOCHS = 10 + +image_cache = {} +train_gen = generate_image_triples_batch(triples_train, BATCH_SIZE, shuffle=True) +val_gen = generate_image_triples_batch(triples_test, BATCH_SIZE, shuffle=False) + +num_train_steps = len(triples_train) // BATCH_SIZE +num_val_steps = len(triples_test) // BATCH_SIZE + +# nn = e2e_network() +nn = e2e_pretrained_network() +history = nn.fit_generator(train_gen, + steps_per_epoch=num_train_steps, + epochs=NUM_EPOCHS, + validation_data=val_gen, + validation_steps=num_val_steps) + +vis_utils.plot_training_curve(history) |
