summaryrefslogtreecommitdiff
path: root/cv/holiday_similarity/main.py
blob: a292aab9a1640e822c48a74c0551a16ed119a1eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from cv.holiday_similarity import vis_utils
from cv.holiday_similarity.data_utils import *
from cv.holiday_similarity.prepare_models import e2e_pretrained_network

# Training configuration.
BATCH_SIZE = 64
NUM_EPOCHS = 10
TRAIN_FRACTION = 0.7  # share of triples used for training; the remainder is validation


def _num_steps(num_samples, batch_size):
    """Return how many batches to draw per epoch from a split of ``num_samples`` items.

    Floor division drops a trailing partial batch, but on its own it yields 0
    when the split is smaller than one batch, and a step count of 0 draws no
    batches at all. Clamp to at least one step so small splits are still used.

    NOTE(review): this assumes generate_image_triples_batch still yields a
    (partial) batch when a split holds fewer than ``batch_size`` triples —
    confirm in data_utils.

    Raises:
        ValueError: if ``num_samples`` is not positive (nothing to iterate over).
    """
    if num_samples <= 0:
        raise ValueError("cannot build batches from an empty split")
    return max(1, num_samples // batch_size)


triples_data = create_triples(IMAGE_DIR)
# NOTE(review): the split is positional, not random. If create_triples returns
# triples in a structured order (e.g. grouped by label), the validation split
# will be skewed — confirm create_triples shuffles, or shuffle here with a seed.
split_point = int(len(triples_data) * TRAIN_FRACTION)
triples_train, triples_test = triples_data[:split_point], triples_data[split_point:]

# Only the training stream is shuffled; validation triples keep a fixed order.
train_gen = generate_image_triples_batch(triples_train, BATCH_SIZE, shuffle=True)
val_gen = generate_image_triples_batch(triples_test, BATCH_SIZE, shuffle=False)

num_train_steps = _num_steps(len(triples_train), BATCH_SIZE)
num_val_steps = _num_steps(len(triples_test), BATCH_SIZE)

# Alternative model: e2e_network(). The pretrained variant is used here.
nn = e2e_pretrained_network()
# NOTE: fit_generator is deprecated in TensorFlow 2.x Keras in favour of fit(),
# which accepts generators directly; switch if the project runs on TF2.
history = nn.fit_generator(train_gen,
                           steps_per_epoch=num_train_steps,
                           epochs=NUM_EPOCHS,
                           validation_data=val_gen,
                           validation_steps=num_val_steps)

vis_utils.plot_training_curve(history)