path: root/cv/holiday_similarity/data_utils.py
author     zhang <zch921005@126.com>    2020-08-08 20:21:47 +0800
committer  zhang <zch921005@126.com>    2020-08-08 20:21:47 +0800
commit     2816f0ecda446dbd902bfab4a13d7bc95b0a5d33 (patch)
tree       0aac7bda9692de91327231fa58a4540126548d3d /cv/holiday_similarity/data_utils.py
parent     8ebc34e31433d73d630d1431acd80ce2e922395b (diff)
holiday similarity update
Diffstat (limited to 'cv/holiday_similarity/data_utils.py')
-rw-r--r--  cv/holiday_similarity/data_utils.py  183
1 file changed, 183 insertions(+), 0 deletions(-)
diff --git a/cv/holiday_similarity/data_utils.py b/cv/holiday_similarity/data_utils.py
new file mode 100644
index 0000000..1601706
--- /dev/null
+++ b/cv/holiday_similarity/data_utils.py
@@ -0,0 +1,183 @@
+import itertools
+import os
+from random import shuffle
+
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+from keras.applications import imagenet_utils
+from keras.utils import np_utils
+
+IMAGE_DIR = '/Users/chunhuizhang/workspaces/00_datasets/images/INRIA Holidays dataset /jpg'
+# IMAGE_DIR = os.path.join(DATA_DIR, "holiday-photos")
+
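+# simple in-memory cache: image filename -> preprocessed numpy array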
+image_cache = {}
+
+
+def pair_generator(triples, image_cache, datagens, batch_size=32):
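+    # Generator for the image-based siamese model: draws random batches of
+    # (left image, right image) pairs from image_cache and yields ([X1, X2], Y)
+    # with Y one-hot encoded; if Keras ImageDataGenerators are passed via
+    # `datagens`, each side is additionally augmented with random_transform.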
+ while True:
+        # reshuffle once per pass over the data (i.e. once per epoch)
+ indices = np.random.permutation(np.arange(len(triples)))
+ num_batches = len(triples) // batch_size
+ for bid in range(num_batches):
+ batch_indices = indices[bid * batch_size: (bid + 1) * batch_size]
+ batch = [triples[i] for i in batch_indices]
+ X1 = np.zeros((batch_size, 224, 224, 3))
+ X2 = np.zeros((batch_size, 224, 224, 3))
+ Y = np.zeros((batch_size, 2))
+ for i, (image_filename_l, image_filename_r, label) in enumerate(batch):
+ if datagens is None or len(datagens) == 0:
+ X1[i] = image_cache[image_filename_l]
+ X2[i] = image_cache[image_filename_r]
+ else:
+ X1[i] = datagens[0].random_transform(image_cache[image_filename_l])
+ X2[i] = datagens[1].random_transform(image_cache[image_filename_r])
+ Y[i] = [1, 0] if label == 0 else [0, 1]
+ yield [X1, X2], Y
+
+
+def image_batch_generator(image_names, batch_size):
+    # yield the image names in fixed-size batches, then the remainder (if any)
+    num_batches = len(image_names) // batch_size
+    for i in range(num_batches):
+        batch = image_names[i * batch_size: (i + 1) * batch_size]
+        yield batch
+    remainder = image_names[num_batches * batch_size:]
+    if remainder:
+        yield remainder
+
+
+def show_img(sid, img_file, img_title):
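+    # plot a single image (resized to 512x512) at the given subplot position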
+ plt.subplot(sid)
+ plt.title(img_title)
+ plt.xticks([])
+ plt.yticks([])
+ img = np.asarray(Image.fromarray(plt.imread(img_file)).resize((512, 512)))
+ plt.imshow(img)
+
+
+def get_random_image(img_groups, group_names, gid):
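+    # pick a random photo from the group at index `gid` and return its filename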
+ gname = group_names[gid]
+ photos = img_groups[gname]
+ pid = np.random.choice(np.arange(len(photos)), size=1)[0]
+ pname = photos[pid]
+ return gname + pname + ".jpg"
+
+
+def create_triples(image_dir):
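+    # Build (left, right, label) triples from the Holidays naming scheme:
+    # the first 4 digits of a filename are the group id, the rest the photo id.
+    # Every pair within a group is a positive (label 1); an equal number of
+    # cross-group pairs is sampled as negatives (label 0).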
+ img_groups = {}
+ for img_file in os.listdir(image_dir):
+ prefix, suffix = img_file.split(".")
+ gid, pid = prefix[0:4], prefix[4:]
+ if gid in img_groups:
+ img_groups[gid].append(pid)
+ else:
+ img_groups[gid] = [pid]
+ pos_triples, neg_triples = [], []
+ # positive pairs are any combination of images in same group
+ for key in img_groups.keys():
+ triples = [(key + x[0] + ".jpg", key + x[1] + ".jpg", 1)
+ for x in itertools.combinations(img_groups[key], 2)]
+ pos_triples.extend(triples)
+ # need equal number of negative examples
+ group_names = list(img_groups.keys())
+ # pos:neg == 1:1
+ for i in range(len(pos_triples)):
+ g1, g2 = np.random.choice(np.arange(len(group_names)), size=2, replace=False)
+ left = get_random_image(img_groups, group_names, g1)
+ right = get_random_image(img_groups, group_names, g2)
+ neg_triples.append((left, right, 0))
+ pos_triples.extend(neg_triples)
+ shuffle(pos_triples)
+ return pos_triples
+
+
+def load_image(image_name, imagenet=False):
+    if image_name not in image_cache:
+        # JPEGs are read as uint8 arrays; resize to the 224x224 network input
+        image = plt.imread(os.path.join(IMAGE_DIR, image_name)).astype(np.uint8)
+        image = np.asarray(Image.fromarray(image).resize((224, 224)))
+        if imagenet:
+            # Keras ImageNet preprocessing (channel reordering / mean subtraction)
+            image = imagenet_utils.preprocess_input(image.astype(np.float32))
+        else:
+            # simple scaling to [0, 1)
+            image = np.divide(image, 256.)
+        # note: the cache keeps whichever preprocessing was applied on the first call
+        image_cache[image_name] = image
+    return image_cache[image_name]
+
+
+def generate_image_triples_batch(image_triples, batch_size, shuffle=False):
+ while True:
+ # loop once per epoch
+ if shuffle:
+ indices = np.random.permutation(np.arange(len(image_triples)))
+ else:
+ indices = np.arange(len(image_triples))
+ shuffled_triples = [image_triples[ix] for ix in indices]
+ num_batches = len(shuffled_triples) // batch_size
+ for bid in range(num_batches):
+ # loop once per batch
+ images_left, images_right, labels = [], [], []
+ batch = shuffled_triples[bid * batch_size: (bid + 1) * batch_size]
+ for i in range(batch_size):
+ lhs, rhs, label = batch[i]
+                # use the same preprocessing for both branches of the pair
+                images_left.append(load_image(lhs, imagenet=True))
+                images_right.append(load_image(rhs, imagenet=True))
+ labels.append(label)
+ Xlhs = np.array(images_left)
+ Xrhs = np.array(images_right)
+ Y = np_utils.to_categorical(np.array(labels), num_classes=2)
+ yield ([Xlhs, Xrhs], Y)
+
+
+def train_test_split(triples, splits):
+    assert np.isclose(sum(splits), 1.0), "splits must sum to 1.0"
+ split_pts = np.cumsum(np.array([0.] + splits))
+ indices = np.random.permutation(np.arange(len(triples)))
+ shuffled_triples = [triples[i] for i in indices]
+ data_splits = []
+ for sid in range(len(splits)):
+ start = int(split_pts[sid] * len(triples))
+ end = int(split_pts[sid + 1] * len(triples))
+ data_splits.append(shuffled_triples[start:end])
+ return data_splits
+
+
+def batch_to_vectors(batch, vec_size, vec_dict):
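+    # look up the precomputed feature vectors for both images of each triple
+    # and one-hot encode the labels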
+ X1 = np.zeros((len(batch), vec_size))
+ X2 = np.zeros((len(batch), vec_size))
+ Y = np.zeros((len(batch), 2))
+ for tid in range(len(batch)):
+ X1[tid] = vec_dict[batch[tid][0]]
+ X2[tid] = vec_dict[batch[tid][1]]
+ Y[tid] = [1, 0] if batch[tid][2] == 0 else [0, 1]
+ return ([X1, X2], Y)
+
+
+def data_generator(triples, vec_size, vec_dict, batch_size=32):
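+    # Generator for the vector-based model: same batching logic as
+    # pair_generator, but feeds precomputed feature vectors from `vec_dict`.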
+ while True:
+        # reshuffle once per pass over the data (i.e. once per epoch)
+ indices = np.random.permutation(np.arange(len(triples)))
+ num_batches = len(triples) // batch_size
+ for bid in range(num_batches):
+ batch_indices = indices[bid * batch_size: (bid + 1) * batch_size]
+ batch = [triples[i] for i in batch_indices]
+ yield batch_to_vectors(batch, vec_size, vec_dict)
+
+
+if __name__ == '__main__':
+ show_img(131, os.path.join(IMAGE_DIR, "115200.jpg"), "original")
+ show_img(132, os.path.join(IMAGE_DIR, "115201.jpg"), "similar")
+ show_img(133, os.path.join(IMAGE_DIR, "123700.jpg"), "different")
+ plt.tight_layout()
+ plt.show()
+
+ triples_data = create_triples(IMAGE_DIR)
+
+ print("# image triples:", len(triples_data))
+    print(triples_data[:5])
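+
+    # A minimal, hypothetical usage sketch of the helpers above; the split
+    # ratios and batch size are assumptions, not part of this module:
+    # train_triples, val_triples = train_test_split(triples_data, splits=[0.8, 0.2])
+    # train_gen = generate_image_triples_batch(train_triples, batch_size=32, shuffle=True)
+    # (X1, X2), Y = next(train_gen)
+    # print(X1.shape, X2.shape, Y.shape)  # -> (32, 224, 224, 3) (32, 224, 224, 3) (32, 2)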