| author | zhang <zch921005@126.com> | 2020-08-08 20:21:47 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2020-08-08 20:21:47 +0800 |
| commit | 2816f0ecda446dbd902bfab4a13d7bc95b0a5d33 (patch) | |
| tree | 0aac7bda9692de91327231fa58a4540126548d3d /cv/holiday_similarity/data_utils.py | |
| parent | 8ebc34e31433d73d630d1431acd80ce2e922395b (diff) | |
holiday similarity update
Diffstat (limited to 'cv/holiday_similarity/data_utils.py')
| -rw-r--r-- | cv/holiday_similarity/data_utils.py | 183 |
1 file changed, 183 insertions, 0 deletions
diff --git a/cv/holiday_similarity/data_utils.py b/cv/holiday_similarity/data_utils.py
new file mode 100644
index 0000000..1601706
--- /dev/null
+++ b/cv/holiday_similarity/data_utils.py
@@ -0,0 +1,183 @@
+import itertools
+import os
+from random import shuffle
+
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+from keras.applications import imagenet_utils
+from keras.utils import np_utils
+
+IMAGE_DIR = '/Users/chunhuizhang/workspaces/00_datasets/images/INRIA Holidays dataset /jpg'
+# IMAGE_DIR = os.path.join(DATA_DIR, "holiday-photos")
+
+# cache of preprocessed images, keyed by filename
+image_cache = {}
+
+
+def pair_generator(triples, image_cache, datagens, batch_size=32):
+    """Yield ([X_left, X_right], Y) batches from cached images, optionally augmented."""
+    while True:
+        # reshuffle once per epoch (one pass over all triples)
+        indices = np.random.permutation(np.arange(len(triples)))
+        num_batches = len(triples) // batch_size
+        for bid in range(num_batches):
+            batch_indices = indices[bid * batch_size: (bid + 1) * batch_size]
+            batch = [triples[i] for i in batch_indices]
+            X1 = np.zeros((batch_size, 224, 224, 3))
+            X2 = np.zeros((batch_size, 224, 224, 3))
+            Y = np.zeros((batch_size, 2))
+            for i, (image_filename_l, image_filename_r, label) in enumerate(batch):
+                if datagens is None or len(datagens) == 0:
+                    X1[i] = image_cache[image_filename_l]
+                    X2[i] = image_cache[image_filename_r]
+                else:
+                    X1[i] = datagens[0].random_transform(image_cache[image_filename_l])
+                    X2[i] = datagens[1].random_transform(image_cache[image_filename_r])
+                Y[i] = [1, 0] if label == 0 else [0, 1]
+            yield [X1, X2], Y
+
+
+def image_batch_generator(image_names, batch_size):
+    """Yield successive batches of image names; the last batch may be smaller."""
+    num_batches = len(image_names) // batch_size
+    for i in range(num_batches):
+        yield image_names[i * batch_size: (i + 1) * batch_size]
+    remainder = image_names[num_batches * batch_size:]
+    if remainder:
+        yield remainder
+
+
+def show_img(sid, img_file, img_title):
+    plt.subplot(sid)
+    plt.title(img_title)
+    plt.xticks([])
+    plt.yticks([])
+    img = np.asarray(Image.fromarray(plt.imread(img_file)).resize((512, 512)))
+    plt.imshow(img)
+
+
+def get_random_image(img_groups, group_names, gid):
+    """Pick a random photo from group gid and return its filename."""
+    gname = group_names[gid]
+    photos = img_groups[gname]
+    pid = np.random.choice(np.arange(len(photos)), size=1)[0]
+    pname = photos[pid]
+    return gname + pname + ".jpg"
+
+
+def create_triples(image_dir):
+    """Build (left_image, right_image, label) pairs: label 1 = same group, 0 = different."""
+    img_groups = {}
+    for img_file in os.listdir(image_dir):
+        prefix, suffix = img_file.split(".")
+        gid, pid = prefix[0:4], prefix[4:]
+        if gid in img_groups:
+            img_groups[gid].append(pid)
+        else:
+            img_groups[gid] = [pid]
+    pos_triples, neg_triples = [], []
+    # positive pairs are any combination of images in the same group
+    for key in img_groups.keys():
+        triples = [(key + x[0] + ".jpg", key + x[1] + ".jpg", 1)
+                   for x in itertools.combinations(img_groups[key], 2)]
+        pos_triples.extend(triples)
+    # need an equal number of negative examples (pos:neg == 1:1)
+    group_names = list(img_groups.keys())
+    for i in range(len(pos_triples)):
+        g1, g2 = np.random.choice(np.arange(len(group_names)), size=2, replace=False)
+        left = get_random_image(img_groups, group_names, g1)
+        right = get_random_image(img_groups, group_names, g2)
+        neg_triples.append((left, right, 0))
+    pos_triples.extend(neg_triples)
+    shuffle(pos_triples)
+    return pos_triples
+
+
+def load_image(image_name, imagenet=False):
+    """Load, resize to 224x224 and normalise an image, caching the result by filename."""
+    if image_name not in image_cache:
+        image = plt.imread(os.path.join(IMAGE_DIR, image_name)).astype(np.uint8)
+        image = np.asarray(Image.fromarray(image).resize((224, 224))).astype(np.float32)
+        if imagenet:
+            image = imagenet_utils.preprocess_input(image)
+        else:
+            image = np.divide(image, 256)
+        image_cache[image_name] = image
+    return image_cache[image_name]
+
+
+def generate_image_triples_batch(image_triples, batch_size, shuffle=False):
+    while True:
+        # loop once per epoch
+        if shuffle:
+            indices = np.random.permutation(np.arange(len(image_triples)))
+        else:
+            indices = np.arange(len(image_triples))
+        shuffled_triples = [image_triples[ix] for ix in indices]
+        num_batches = len(shuffled_triples) // batch_size
+        for bid in range(num_batches):
+            # loop once per batch
+            images_left, images_right, labels = [], [], []
+            batch = shuffled_triples[bid * batch_size: (bid + 1) * batch_size]
+            for i in range(batch_size):
+                lhs, rhs, label = batch[i]
+                images_left.append(load_image(lhs, imagenet=True))
+                images_right.append(load_image(rhs))
+                labels.append(label)
+            Xlhs = np.array(images_left)
+            Xrhs = np.array(images_right)
+            Y = np_utils.to_categorical(np.array(labels), num_classes=2)
+            yield ([Xlhs, Xrhs], Y)
+
+
+def train_test_split(triples, splits):
+    """Shuffle triples and cut them into len(splits) parts with the given fractions."""
+    assert sum(splits) == 1.0
+    split_pts = np.cumsum(np.array([0.] + splits))
+    indices = np.random.permutation(np.arange(len(triples)))
+    shuffled_triples = [triples[i] for i in indices]
+    data_splits = []
+    for sid in range(len(splits)):
+        start = int(split_pts[sid] * len(triples))
+        end = int(split_pts[sid + 1] * len(triples))
+        data_splits.append(shuffled_triples[start:end])
+    return data_splits
+
+
+def batch_to_vectors(batch, vec_size, vec_dict):
+    """Look up precomputed feature vectors for a batch of triples."""
+    X1 = np.zeros((len(batch), vec_size))
+    X2 = np.zeros((len(batch), vec_size))
+    Y = np.zeros((len(batch), 2))
+    for tid in range(len(batch)):
+        X1[tid] = vec_dict[batch[tid][0]]
+        X2[tid] = vec_dict[batch[tid][1]]
+        Y[tid] = [1, 0] if batch[tid][2] == 0 else [0, 1]
+    return ([X1, X2], Y)
+
+
+def data_generator(triples, vec_size, vec_dict, batch_size=32):
+    while True:
+        # reshuffle once per epoch
+        indices = np.random.permutation(np.arange(len(triples)))
+        num_batches = len(triples) // batch_size
+        for bid in range(num_batches):
+            batch_indices = indices[bid * batch_size: (bid + 1) * batch_size]
+            batch = [triples[i] for i in batch_indices]
+            yield batch_to_vectors(batch, vec_size, vec_dict)
+
+
+if __name__ == '__main__':
+    show_img(131, os.path.join(IMAGE_DIR, "115200.jpg"), "original")
+    show_img(132, os.path.join(IMAGE_DIR, "115201.jpg"), "similar")
+    show_img(133, os.path.join(IMAGE_DIR, "123700.jpg"), "different")
+    plt.tight_layout()
+    plt.show()
+
+    triples_data = create_triples(IMAGE_DIR)
+
+    print("# image triples:", len(triples_data))
+    print([x for x in triples_data[0:5]])
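For orientation (this is not part of the commit), here is a minimal sketch of how the helpers above could be wired into Siamese training: `create_triples` builds labelled pairs, `train_test_split` partitions them, `load_image` fills the shared `image_cache`, and `pair_generator` feeds training batches. The tiny shared tower `build_siamese`, the 0.5/0.25/0.25 split, and the `horizontal_flip` augmentation are illustrative assumptions, as is the Keras 2-era `fit_generator` call implied by the file's `keras.utils.np_utils` import.

```python
# Illustrative sketch only -- nothing here is in the commit. The model, splits and
# augmentation settings are placeholders showing one way to consume data_utils.
from keras.layers import Conv2D, Dense, GlobalAveragePooling2D, Input, concatenate
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator

import data_utils as du

# build labelled pairs and split them; exact binary fractions keep the strict
# sum(splits) == 1.0 assertion in train_test_split happy
triples = du.create_triples(du.IMAGE_DIR)
train_triples, val_triples, test_triples = du.train_test_split(triples, splits=[0.5, 0.25, 0.25])

# pre-load every image into the module-level cache that pair_generator indexes
# (the whole dataset is held in memory, which is what the cache is designed for)
for name_l, name_r, _ in triples:
    du.load_image(name_l)
    du.load_image(name_r)

# one augmenter per branch, as pair_generator expects (datagens[0] / datagens[1])
datagens = [ImageDataGenerator(horizontal_flip=True),
            ImageDataGenerator(horizontal_flip=True)]


def build_siamese(input_shape=(224, 224, 3)):
    # shared convolutional tower applied to both inputs (weights are shared)
    inp = Input(shape=input_shape)
    feat = GlobalAveragePooling2D()(Conv2D(8, (3, 3), activation="relu")(inp))
    tower = Model(inp, feat)

    left, right = Input(shape=input_shape), Input(shape=input_shape)
    merged = concatenate([tower(left), tower(right)])
    out = Dense(2, activation="softmax")(merged)  # matches the 2-column Y batches
    return Model(inputs=[left, right], outputs=out)


model = build_siamese()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

batch_size = 32
model.fit_generator(
    du.pair_generator(train_triples, du.image_cache, datagens, batch_size),
    steps_per_epoch=len(train_triples) // batch_size,
    validation_data=du.pair_generator(val_triples, du.image_cache, None, batch_size),
    validation_steps=len(val_triples) // batch_size,
    epochs=1,
)
```

The vector-based path (`batch_to_vectors` / `data_generator`) would be used the same way when each image is represented by a precomputed embedding vector in `vec_dict` instead of raw pixels.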
