summaryrefslogtreecommitdiff
path: root/cv/seg/utils.py
blob: 0329c8414670826e081fa133c06cc18c2691cf9c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python

"""
Function for interactively selecting part of an array displayed as an image with matplotlib.
"""

import matplotlib.pyplot as plt
from matplotlib import is_interactive
from matplotlib.path import Path
from matplotlib.widgets import LassoSelector, RectangleSelector
import numpy as np


def path_bbox(p):
    """
    Return rectangular bounding box of given path.

    Parameters
    ----------
    p : array_like
        Array of vertices with shape Nx2 (column 0 is x, column 1 is y).

    Returns
    -------
    bbox : numpy.ndarray
        Array of bounding box vertices with shape 4x2, ordered
        (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin).

    Raises
    ------
    ValueError
        If `p` is not an Nx2 array or contains no vertices.
    """
    # Accept any array_like (e.g. a list of (x, y) pairs), as documented.
    p = np.asarray(p)

    # Raise instead of `assert`, which is stripped when Python runs with -O.
    if p.ndim != 2 or p.shape[1] != 2:
        raise ValueError('expected an Nx2 array of vertices, got shape %r'
                         % (p.shape,))
    if p.shape[0] == 0:
        raise ValueError('path has no vertices')

    x_min, y_min = p.min(axis=0)
    x_max, y_max = p.max(axis=0)

    return np.array([[x_min, y_min],
                     [x_min, y_max],
                     [x_max, y_max],
                     [x_max, y_min]])


def imshow_select(data, selector='lasso', bbox=False):
    """
    Display array as image with region selector.

    The array is shown in a new figure and the user draws regions with the
    chosen selector; every drawn region is added to the selection. The call
    blocks until Enter is pressed in the terminal.

    Parameters
    ----------
    data : array_like
        Array to display. Its first two axes are treated as image rows and
        columns; any trailing axes (e.g. colour channels) are carried along.
    selector : str
        Region selector. For `lasso`, use `LassoSelector`; for `rectangle`,
        use `RectangleSelector`.
    bbox : bool
        If True, only return array within rectangular bounding box of the
        selected region (the union of all drawn regions, edges inclusive).
        Otherwise, return array with same dimensions as `data` such that
        selected region contains the corresponding values from `data` and the
        remainder contains 0.

    Returns
    -------
    region : array_like
        Data for selected region. If nothing was selected, this is an array
        of zeros shaped like `data` regardless of `bbox`.
    mask : array_like
        Boolean mask with same shape of `data` for selecting the returned
        region from `data`.

    Raises
    ------
    KeyError
        If `selector` is not one of the supported names.
    """
    data = np.asarray(data)

    # Resolve the selector before touching any pyplot state so an unknown
    # name cannot leave interactive mode switched on.
    name_to_selector = {'lasso': LassoSelector,
                        'rectangle': RectangleSelector}
    selector_cls = name_to_selector[selector]

    nrows, ncols = data.shape[:2]
    x, y = np.meshgrid(np.arange(ncols, dtype=int),
                       np.arange(nrows, dtype=int))
    # (x, y) pixel coordinates in row-major order, matching reshape(nrows, ncols).
    pix = np.column_stack((x.ravel(), y.ravel()))

    # The callbacks only mutate these arrays in place and never rebind them,
    # so no mutable-container workaround is needed for Python 2.7 closures.
    region = np.zeros_like(data)
    mask = np.zeros(data.shape, dtype=bool)
    hit = np.zeros((nrows, ncols), dtype=bool)  # 2-D union of selected pixels

    def _add_path(verts):
        # Mark every pixel inside the path. radius=1 is kept from the original
        # code; note matplotlib applies the radius according to the path's
        # vertex orientation.
        inside = Path(verts).contains_points(pix, radius=1).reshape(nrows, ncols)
        hit[inside] = True
        # Index with the 2-D pixel mask so (H, W, C) data copies all channels.
        region[inside] = data[inside]
        mask[inside] = True

    def _onselect_lasso(verts):
        _add_path(np.asarray(verts))

    def _onselect_rectangle(start, end):
        _add_path(np.array([[start.xdata, start.ydata],
                            [start.xdata, end.ydata],
                            [end.xdata, end.ydata],
                            [end.xdata, start.ydata]], int))

    onselect = {LassoSelector: _onselect_lasso,
                RectangleSelector: _onselect_rectangle}[selector_cls]
    kwargs = {LassoSelector: {},
              RectangleSelector: {'interactive': True}}[selector_cls]

    # Python 2's input() evaluates the reply, so a bare Enter would raise.
    try:
        read_line = raw_input  # noqa: F821
    except NameError:
        read_line = input

    interactive = is_interactive()
    if not interactive:
        plt.ion()
    widget = None  # keep a strong reference so the widget stays alive
    try:
        fig = plt.figure()
        ax = fig.gca()
        ax.imshow(data)
        widget = selector_cls(ax, onselect, **kwargs)
        read_line('Press Enter when done')
    finally:
        # Restore caller state even if the prompt is interrupted.
        if widget is not None:
            widget.disconnect_events()
        if not interactive:
            plt.ioff()

    if bbox and hit.any():
        # Crop once, after all selections, to the inclusive pixel bounding box.
        rows = np.flatnonzero(hit.any(axis=1))
        cols = np.flatnonzero(hit.any(axis=0))
        region = region[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
    return region, mask


if __name__ == '__main__':
    # Demo: select a region of the scikit-image "coins" sample with the lasso,
    # crop it to its bounding box, and save the crop and the full-size mask.
    from skimage.data import coins

    image = coins()
    region, region_mask = imshow_select(image, selector='lasso', bbox=True)
    for filename, array in (('selected.png', region),
                            ('mask.png', region_mask)):
        plt.imsave(filename, array)