scikit-image

Форк
0
926 строк · 32.1 Кб
1
# cython: cdivision=True
2
# cython: boundscheck=False
3
# cython: nonecheck=False
4
# cython: wraparound=False
5

6

7
import numpy as np
8
cimport numpy as cnp
9
from . cimport safe_openmp as openmp
10
from .safe_openmp cimport have_openmp
11
from libc.stdlib cimport malloc, free
12
from libcpp.vector cimport vector
13

14
from skimage._shared.interpolation cimport round, fmax, fmin
15

16
from cython.parallel import prange
17
from ..color import rgb2gray
18
from ..transform import integral_image
19
import xml.etree.ElementTree as ET
20
from ._texture cimport _multiblock_lbp
21
import math
22

23
cnp.import_array()
24

25
# One detected rectangle: (r, c) is the row/column of its top left corner,
# width/height are its extent.
cdef struct Detection:
    int r, c
    int width, height
32

33

34
# Accumulator for a cluster of detection rectangles.
# Rectangles are added to a cluster one at a time, so instead of keeping
# every member we keep the running sums of the row/column positions, widths
# and heights together with the number of members. The average rectangle of
# the cluster can then be computed in constant time.
cdef struct DetectionsCluster:
    int r_sum, c_sum
    int width_sum, height_sum
    int count
47

48

49
# Position and size parameters of one multi-block binary pattern (MB-LBP)
# feature. See skimage.feature.texture.multiblock_lbp for their meaning.
cdef struct MBLBP:
    Py_ssize_t r, c
    Py_ssize_t width, height
58

59

60
# One trained MB-LBP feature (a decision stump).
# `feature_id` indexes the array of MBLBP structs; an index is stored rather
# than the feature itself because stages of a cascade can reuse features.
# `lut_idx` indexes the look-up table which, for the computed value of the
# feature, tells whether an object is present in the current detection
# window. Depending on the look-up table value (0 or 1), the positive
# (`right`) or negative (`left`) weight is added to the score of the stage.
cdef struct MBLBPStump:
    Py_ssize_t feature_id, lut_idx
    cnp.float32_t left, right
74

75

76
# One stage of the classifier, made of consecutive MBLBPStumps.
# `first_idx` is the index of the first stump of the stage and `amount` is
# the number of stumps that belong to it starting from that index. All the
# stumps of a stage are evaluated, their outputs (`left` or `right`,
# depending on the input) are summed and the sum is compared to `threshold`.
# A stage whose sum is high enough is passed and the cascade moves on to the
# next stage; when every stage is passed the object is predicted to be
# present in the input image patch.
cdef struct Stage:
    Py_ssize_t first_idx, amount
    cnp.float32_t threshold
89

90

91
cdef vector[Detection] _group_detections(vector[Detection] detections,
                                         cnp.float32_t intersection_score_threshold=0.5,
                                         int min_neighbor_number=4):
    """Merge overlapping detections and drop weakly supported ones.

    A true detection is assumed to show up as many overlapping rectangles.
    Every detection is therefore assigned to the existing cluster whose mean
    rectangle it overlaps best, provided the overlap score is strictly above
    the threshold; otherwise it starts a new cluster. Clusters with too few
    members are discarded and each remaining cluster is replaced by its mean
    rectangle (mean top left corner, mean width and mean height).

    Parameters
    ----------
    detections : vector[Detection]
        The raw detections to group.
    intersection_score_threshold : cnp.float32_t
        A detection joins a cluster only if the ratio
        (intersection area) / (area of the smaller rectangle) between it and
        the cluster mean exceeds this value.
    min_neighbor_number : int
        Minimum number of detections a cluster must contain to be kept.

    Returns
    -------
    output : vector[Detection]
        One mean detection per surviving cluster.
    """

    cdef:
        vector[DetectionsCluster] clusters
        Detection candidate
        Py_ssize_t det_idx
        Py_ssize_t cluster_idx
        Py_ssize_t target_idx
        Py_ssize_t n_clusters
        Py_ssize_t n_detections = detections.size()
        cnp.float32_t top_score
        cnp.float32_t score

    for det_idx in range(n_detections):
        candidate = detections[det_idx]

        # A cluster must beat the threshold to be selected; -1 means
        # "no suitable cluster found". The very first detection always ends
        # up here with no clusters to compare against and seeds the first one.
        top_score = intersection_score_threshold
        target_idx = -1
        n_clusters = clusters.size()

        for cluster_idx in range(n_clusters):
            score = rect_intersection_score(
                candidate,
                mean_detection_from_cluster(clusters[cluster_idx]))

            if score > top_score:
                top_score = score
                target_idx = cluster_idx

        if target_idx < 0:
            clusters.push_back(cluster_from_detection(candidate))
        else:
            clusters[target_idx] = update_cluster(clusters[target_idx],
                                                  candidate)

    return get_mean_detections(
        threshold_clusters(clusters, min_neighbor_number))
174

175

176
cdef DetectionsCluster update_cluster(DetectionsCluster cluster,
                                      Detection detection):
    """Return the cluster extended by one more detection.

    The position and size of the detection are added to the running sums of
    the cluster and its member count is incremented, so the detection
    contributes to the mean value of the cluster.

    Parameters
    ----------
    cluster : DetectionsCluster
        A cluster of detections.
    detection : Detection
        The detection to be added to the cluster.

    Returns
    -------
    updated_cluster : DetectionsCluster
        The updated cluster.
    """

    # `cluster` is received by value, so it can be modified in place and
    # handed back without affecting the caller's copy.
    cluster.count += 1
    cluster.r_sum += detection.r
    cluster.c_sum += detection.c
    cluster.width_sum += detection.width
    cluster.height_sum += detection.height

    return cluster
205

206

207
cdef Detection mean_detection_from_cluster(DetectionsCluster cluster):
    """Average all rectangles of a cluster into a single detection.

    Each accumulated sum of the cluster is divided by the number of
    rectangles that were added to it.

    Parameters
    ----------
    cluster : DetectionsCluster
        A cluster of detections.

    Returns
    -------
    mean : Detection
        The mean detection.
    """

    cdef:
        Detection averaged
        int members = cluster.count

    averaged.height = cluster.height_sum / members
    averaged.width = cluster.width_sum / members
    averaged.c = cluster.c_sum / members
    averaged.r = cluster.r_sum / members

    return averaged
232

233

234
cdef DetectionsCluster cluster_from_detection(Detection detection):
    """Start a new cluster that contains only the given detection.

    The sums of the cluster are initialised with the values of the
    detection and its count is set to one.

    Parameters
    ----------
    detection : Detection
        A single detection.

    Returns
    -------
    new_cluster : DetectionsCluster
        The cluster struct that was created from the detection.
    """

    cdef DetectionsCluster seeded = DetectionsCluster(
        r_sum=detection.r,
        c_sum=detection.c,
        width_sum=detection.width,
        height_sum=detection.height,
        count=1)

    return seeded
259

260

261
cdef vector[DetectionsCluster] threshold_clusters(vector[DetectionsCluster] clusters,
                                                  int count_threshold):
    """Keep only the clusters that contain enough rectangles.

    A cluster is kept when its number of rectangles is at least
    ``count_threshold``; the relative order of the kept clusters is
    preserved.

    Parameters
    ----------
    clusters : vector[DetectionsCluster]
        Array of rectangles clusters.
    count_threshold : int
        The minimum number of rectangles a cluster needs in order to be kept.

    Returns
    -------
    output : vector[DetectionsCluster]
        The array of clusters that satisfy the threshold criteria.
    """

    cdef:
        vector[DetectionsCluster] kept
        DetectionsCluster candidate
        Py_ssize_t idx
        Py_ssize_t total = clusters.size()

    for idx in range(total):
        candidate = clusters[idx]

        if candidate.count < count_threshold:
            continue

        kept.push_back(candidate)

    return kept
294

295

296
cdef vector[Detection] get_mean_detections(vector[DetectionsCluster] clusters):
    """Replace every cluster with its mean detection.

    The output holds, in the same order as the input, one detection per
    cluster computed as the mean of the rectangles that belong to it.

    Parameters
    ----------
    clusters : vector[DetectionsCluster]
        Array of rectangles clusters.

    Returns
    -------
    detections : vector[Detection]
        The array of mean detections, one for each cluster.
    """

    cdef:
        vector[Detection] means
        Py_ssize_t idx
        Py_ssize_t total = clusters.size()

    # The final size is known up front, so allocate once and append.
    means.reserve(total)

    for idx in range(total):
        means.push_back(mean_detection_from_cluster(clusters[idx]))

    return means
326

327

328
cdef cnp.float32_t rect_intersection_area(Detection rect_a, Detection rect_b):
    """Compute the area shared by two rectangles.

    The overlap along each axis is the distance between the larger of the
    two starting coordinates and the smaller of the two ending coordinates,
    clipped at zero when the rectangles do not overlap on that axis.

    Parameters
    ----------
    rect_a : Detection
        Struct of the first rectangle.
    rect_b : Detection
        Struct of the second rectangle.

    Returns
    -------
    result : cnp.float32_t
        The intersection area.
    """

    cdef:
        # Rows covered by each rectangle.
        Py_ssize_t top_a = rect_a.r
        Py_ssize_t top_b = rect_b.r
        Py_ssize_t bottom_a = rect_a.r + rect_a.height
        Py_ssize_t bottom_b = rect_b.r + rect_b.height

        # Columns covered by each rectangle.
        Py_ssize_t left_a = rect_a.c
        Py_ssize_t left_b = rect_b.c
        Py_ssize_t right_a = rect_a.c + rect_a.width
        Py_ssize_t right_b = rect_b.c + rect_b.width

    return (fmax(0, fmin(bottom_a, bottom_b) - fmax(top_a, top_b)) *
            fmax(0, fmin(right_a, right_b) - fmax(left_a, left_b)))
358

359

360
cdef cnp.float32_t rect_intersection_score(Detection rect_a, Detection rect_b):
    """Compute the intersection score of two rectangles.

    The score is the intersection area of the rectangles divided by the
    area of the rectangle with the smaller area.

    Parameters
    ----------
    rect_a : Detection
        Struct of the first rectangle.
    rect_b : Detection
        Struct of the second rectangle.

    Returns
    -------
    result : cnp.float32_t
        The intersection score. The number in the interval ``[0, 1]``.
        1 means rectangles fully intersect, 0 means they don't. A degenerate
        rectangle (non-positive area) never intersects anything and yields 0.
    """

    cdef:
        cnp.float32_t intersection_area
        cnp.float32_t smaller_area
        cnp.float32_t area_a = rect_a.height * rect_a.width
        cnp.float32_t area_b = rect_b.height * rect_b.width

    smaller_area = area_a if area_b > area_a else area_b

    # This module is compiled with cdivision=True, so dividing by a zero area
    # would silently produce NaN/inf instead of raising. Report "no overlap"
    # for degenerate rectangles to keep the result inside [0, 1].
    if smaller_area <= 0:
        return 0

    intersection_area = rect_intersection_area(rect_a, rect_b)

    return intersection_area / smaller_area
391

392

393
cdef class Cascade:
394
    """Class for cascade of classifiers that is used for object detection.
395

396
    The main idea behind cascade of classifiers is to create classifiers
397
    of medium accuracy and ensemble them into one strong classifier
398
    instead of just creating a strong one. The second advantage of cascade
399
    classifier is that easy examples can be classified only by evaluating
400
    some of the classifiers in the cascade, making the process much faster
401
    than the process of evaluating a one strong classifier.
402

403
    Attributes
404
    ----------
405
    eps : cnp.float32_t
406
        Accuracy parameter. Increasing it, makes the classifier detect less
407
        false positives but at the same time the false negative score increases.
408
    stages_number : Py_ssize_t
409
        Amount of stages in a cascade. Each cascade consists of stumps i.e.
410
        trained features.
411
    stumps_number : Py_ssize_t
412
        The overall amount of stumps in all the stages of cascade.
413
    features_number : Py_ssize_t
414
        The overall amount of different features used by cascade.
415
        Two stumps can use the same features but has different trained
416
        values.
417
    window_width : Py_ssize_t
418
        The width of a detection window that is used. Objects smaller than
419
        this window can't be detected.
420
    window_height : Py_ssize_t
421
        The height of a detection window.
422
    stages : Stage*
423
        A pointer to the C array that stores stages information using a
424
        Stage struct.
425
    features : MBLBP*
426
        A pointer to the C array that stores MBLBP features using an MBLBP
427
        struct.
428
    LUTs : cnp.uint32_t*
429
        A pointer to the C array with look-up tables that are used by trained
430
        MBLBP features (MBLBPStumps) to evaluate a particular region.
431

432
    Notes
433
    -----
434
    The cascade approach was first described by Viola and Jones [1]_, [2]_,
435
    although these initial publications used a set of Haar-like features. This
436
    implementation instead uses multi-scale block local binary pattern (MB-LBP)
437
    features [3]_.
438

439
    References
440
    ----------
441
    .. [1] Viola, P. and Jones, M. "Rapid object detection using a boosted
442
           cascade of simple features," In: Proceedings of the 2001 IEEE
443
           Computer Society Conference on Computer Vision and Pattern
444
           Recognition. CVPR 2001, pp. I-I.
445
           :DOI:`10.1109/CVPR.2001.990517`
446
    .. [2] Viola, P. and Jones, M.J, "Robust Real-Time Face Detection",
447
           International Journal of Computer Vision 57, 137–154 (2004).
448
           :DOI:`10.1023/B:VISI.0000013087.49260.fb`
449
    .. [3] Liao, S. et al. Learning Multi-scale Block Local Binary Patterns for
450
           Face Recognition. International Conference on Biometrics (ICB),
451
           2007, pp. 828-837. In: Lecture Notes in Computer Science, vol 4642.
452
           Springer, Berlin, Heidelberg.
453
           :DOI:`10.1007/978-3-540-74549-5_87`
454
    """
455

456
    cdef:
457
        public cnp.float32_t eps
458
        public Py_ssize_t stages_number
459
        public Py_ssize_t stumps_number
460
        public Py_ssize_t features_number
461
        public Py_ssize_t window_width
462
        public Py_ssize_t window_height
463
        Stage* stages
464
        MBLBPStump* stumps
465
        MBLBP* features
466
        cnp.uint32_t* LUTs
467

468
    def __dealloc__(self):

        # Give back the memory held by each of the four C arrays that the
        # instance keeps pointers to.
        free(self.LUTs)
        free(self.features)
        free(self.stumps)
        free(self.stages)
475

476
    def __init__(self, xml_file, eps=1e-5):
        """Create a cascade classifier from a trained cascade file.

        Both arguments are forwarded unchanged to ``_load_xml``.

        Parameters
        ----------
        xml_file : file's path or file's object
            A file in a OpenCv format from which all the cascade classifier's
            parameters are loaded.
        eps : cnp.float32_t
            Accuracy parameter. Increasing it, makes the classifier
            detect less false positives but at the same time the false
            negative score increases.

        """

        self._load_xml(xml_file, eps=eps)
492

493
    cdef bint classify(self, cnp.float32_t[:, ::1] int_img, Py_ssize_t row,
                       Py_ssize_t col, cnp.float32_t scale) noexcept nogil:
        """Run the cascade on one window of the integral image.

        Every feature of the trained cascade is scaled by ``scale``, shifted
        to the window whose top left corner is ``(row, col)`` and evaluated.
        The stages are checked in order and the window is rejected as soon
        as one stage scores below its threshold (lowered by ``self.eps``).

        Parameters
        ----------
        int_img : cnp.float32_t[:, ::1]
            Memory-view to integral image.
        row : Py_ssize_t
            Row coordinate of the top left corner of the window to classify.
        col : Py_ssize_t
            Column coordinate of the top left corner of the window to
            classify.
        scale : cnp.float32_t
            The scale by which the feature positions and sizes are
            multiplied. The products are truncated to integers.

        Returns
        -------
        result : bint
            True if every stage is passed, i.e. the classifier detects the
            object in the specified region, and False otherwise.
        """

        cdef:
            Py_ssize_t stage_idx
            Py_ssize_t stump_offset
            Py_ssize_t feat_r, feat_c, feat_w, feat_h
            int code
            int lut_bit
            cnp.float32_t score
            Stage stage
            MBLBPStump stump
            MBLBP feature

        for stage_idx in range(self.stages_number):

            stage = self.stages[stage_idx]
            score = 0

            for stump_offset in range(stage.amount):

                stump = self.stumps[stage.first_idx + stump_offset]
                feature = self.features[stump.feature_id]

                # Scale the trained feature geometry to the current window.
                feat_r = <Py_ssize_t>(feature.r * scale)
                feat_c = <Py_ssize_t>(feature.c * scale)
                feat_w = <Py_ssize_t>(feature.width * scale)
                feat_h = <Py_ssize_t>(feature.height * scale)

                code = _multiblock_lbp(int_img, row + feat_r, col + feat_c,
                                       feat_w, feat_h)

                # The look-up table is a bit set stored in 32-bit words:
                # `code >> 5` selects the word and `code & 31` the bit in it.
                lut_bit = (self.LUTs[stump.lut_idx + (code >> 5)]
                           >> (code & 31)) & 1

                if lut_bit:
                    score += stump.left
                else:
                    score += stump.right

            if score < (stage.threshold - self.eps):
                return False

        return True
570

571
    def _get_valid_scale_factors(self, min_size, max_size, scale_step):
        """Compute the scale multipliers to apply to the trained window.

        The size limits are expressed as exponents of ``scale_step`` relative
        to the trained window size. The exponents start at the larger of the
        two lower limits (never below zero, i.e. never smaller than the
        trained window), advance in steps of one and stop before the smaller
        of the two upper limits.

        Parameters
        ----------
        min_size : tuple (int, int)
            Minimum size of window for which to search the scale factor.
        max_size : tuple (int, int)
            Maximum size of window for which to search the scale factor.
        scale_step : cnp.float32_t
            The scale by which the search window is multiplied
            on each iteration.

        Returns
        -------
        scale_factors : 1-D cnp.float32_ts ndarray
            The scale factors, ``scale_step`` raised to each selected
            exponent.
        """

        base_size = np.array((self.window_height, self.window_width))
        lower = np.array(min_size, dtype=np.float32)
        upper = np.array(max_size, dtype=np.float32)

        # Exponents (row first, then column) at which the scaled window
        # reaches the upper and the lower size limit.
        max_powers = [math.log(upper[axis] / base_size[axis], scale_step)
                      for axis in (0, 1)]
        min_powers = [math.log(lower[axis] / base_size[axis], scale_step)
                      for axis in (0, 1)]

        first_power = max(min_powers[0], min_powers[1], 0)
        last_power = min(max_powers[0], max_powers[1])

        return np.power(scale_step, np.arange(first_power, last_power),
                        dtype=np.float32)
614

615
    def _get_contiguous_integral_image(self, img):
        """Build the integral image used by the internal cascade routines.

        An input with more than two dimensions is first passed through
        ``rgb2gray``. The integral image of the result is then returned as a
        c-contiguous ``float32`` array.

        Parameters
        ----------
        img : 2-D or 3-D ndarray
            Ndarray that represents the input image.

        Returns
        -------
        int_img : 2-D floats ndarray
            C-contiguous integral image of the input image.
        """
        gray = rgb2gray(img) if len(img.shape) > 2 else img

        return np.ascontiguousarray(integral_image(gray), dtype=np.float32)
640

641

642
    def detect_multi_scale(self, img, cnp.float32_t scale_factor,
                           cnp.float32_t step_ratio, min_size, max_size,
                           min_neighbor_number=4,
                           intersection_score_threshold=0.5):
        """Search for the object on multiple scales of input image.

        The function takes the input image, the scale factor by which the
        searching window is multiplied on each step, minimum window size
        and maximum window size that specify the interval for the search
        windows that are applied to the input image to detect objects.
        Every scale is scanned with a sliding window; the positive windows
        of all scales are then grouped into the final detections.

        Parameters
        ----------
        img : 2-D or 3-D ndarray
            Ndarray that represents the input image.
        scale_factor : cnp.float32_t
            The scale by which searching window is multiplied on each step.
        step_ratio : cnp.float32_t
            The ratio by which the search step is multiplied on each scale
            of the image. 1 represents the exhaustive search and usually is
            slow. By setting this parameter to higher values the results will
            be worse but the computation will be much faster. Usually, values
            in the interval [1, 1.5] give good results. The resulting step is
            never smaller than one pixel.
        min_size : tuple (int, int)
            Minimum size of the search window.
        max_size : tuple (int, int)
            Maximum size of the search window.
        min_neighbor_number : int
            Minimum amount of intersecting detections in order for detection
            to be approved by the function.
        intersection_score_threshold : cnp.float32_t
            The minimum value of the ratio
            (intersection area) / (area of the smaller rectangle) in order
            to merge two detections into one.

        Returns
        -------
        output : list of dicts
            Dict have form {'r': int, 'c': int, 'width': int, 'height': int},
            where 'r' represents row position of top left corner of detected
            window, 'c' - col position, 'width' - width of detected window,
            'height' - height of detected window.
        """

        cdef:
            Py_ssize_t max_row
            Py_ssize_t max_col
            Py_ssize_t current_height
            Py_ssize_t current_width
            Py_ssize_t current_row
            Py_ssize_t current_col
            Py_ssize_t current_step
            Py_ssize_t number_of_scales
            Py_ssize_t img_height
            Py_ssize_t img_width
            Py_ssize_t scale_number
            Py_ssize_t window_height = self.window_height
            Py_ssize_t window_width = self.window_width
            int result
            cnp.float32_t[::1] scale_factors
            cnp.float32_t[:, ::1] int_img
            cnp.float32_t current_scale_factor
            vector[Detection] output
            Detection new_detection
            openmp.omp_lock_t mylock

        int_img = self._get_contiguous_integral_image(img)
        img_height = int_img.shape[0]
        img_width = int_img.shape[1]

        scale_factors = self._get_valid_scale_factors(min_size,
                                                      max_size, scale_factor)
        number_of_scales = scale_factors.shape[0]

        # Initialize lock to enable thread-safe writes to the array
        # in concurrent loop.
        if have_openmp:
            openmp.omp_init_lock(&mylock)

        # As the amount of work between the threads is not equal we
        # use `dynamic` schedule which enables them to use computing
        # power on demand.
        for scale_number in prange(0, number_of_scales,
                                   schedule='dynamic', nogil=True):

            current_scale_factor = scale_factors[scale_number]
            current_step = <Py_ssize_t>round(current_scale_factor * step_ratio)

            # A step that rounds to zero (or is negative) would never advance
            # the sliding window and the loops below would spin forever with
            # the GIL released. Always move by at least one pixel.
            if current_step < 1:
                current_step = 1

            current_height = <Py_ssize_t>(window_height * current_scale_factor)
            current_width = <Py_ssize_t>(window_width * current_scale_factor)
            max_row = img_height - current_height
            max_col = img_width - current_width

            # Check if scaled detection window fits in image.
            if (max_row < 0) or (max_col < 0):
                continue

            current_row = 0
            current_col = 0

            while current_row < max_row:
                while current_col < max_col:

                    result = self.classify(int_img, current_row,
                                           current_col,
                                           scale_factors[scale_number])

                    if result:

                        new_detection.r = current_row
                        new_detection.c = current_col
                        new_detection.width = current_width
                        new_detection.height = current_height

                        # `output` is shared between the threads, so appends
                        # are serialized when OpenMP is available.
                        if have_openmp:
                            openmp.omp_set_lock(&mylock)

                        output.push_back(new_detection)

                        if have_openmp:
                            openmp.omp_unset_lock(&mylock)

                    # Plain assignments (not `+=`) are used on purpose: inside
                    # prange an in-place operator marks the variable as a
                    # reduction instead of a thread-private counter.
                    current_col = current_col + current_step

                current_row = current_row + current_step
                current_col = 0

        if have_openmp:
            openmp.omp_destroy_lock(&mylock)

        return list(_group_detections(output, intersection_score_threshold,
                                      min_neighbor_number))
775

776
    def _load_xml(self, xml_file, eps=1e-5):
777
        """Load the parameters of cascade classifier into the class.
778

779
        The function takes the file with the parameters that represent
780
        trained cascade classifier and loads them into class for later
781
        use.
782

783
        Parameters
784
        ----------
785
        xml_file : filename or file object
786
            File that contains the cascade classifier.
787
        eps : cnp.float32_t
788
            Accuracy parameter. Increasing it, makes the classifier
789
            detect less false positives but at the same time the false
790
            negative score increases.
791

792
        """
793

794
        cdef:
795
            Stage* stages_carr
796
            MBLBPStump* stumps_carr
797
            MBLBP* features_carr
798
            cnp.uint32_t* LUTs_carr
799

800
            cnp.float32_t stage_threshold
801

802
            Py_ssize_t stage_number
803
            Py_ssize_t stages_number
804
            Py_ssize_t window_height
805
            Py_ssize_t window_width
806

807
            Py_ssize_t weak_classifiers_amount
808
            Py_ssize_t weak_classifier_number
809

810
            Py_ssize_t feature_number
811
            Py_ssize_t features_number
812
            Py_ssize_t stump_lut_idx
813
            Py_ssize_t stump_idx
814
            Py_ssize_t i
815

816
            cnp.uint32_t[::1] lut
817

818
            MBLBP new_feature
819
            MBLBPStump new_stump
820
            Stage new_stage
821

822
        tree = ET.parse(xml_file)
823

824
        # Load entities.
825
        features = tree.find('.//features')
826
        stages = tree.find('.//stages')
827

828
        # Get the respective amounts.
829
        stages_number = int(tree.find('.//stageNum').text)
830
        window_height = int(tree.find('.//height').text)
831
        window_width = int(tree.find('.//width').text)
832
        features_number = len(features)
833

834
        # Count the stumps.
835
        stumps_number = 0
836
        for stage_number in range(stages_number):
837
            current_stage = stages[stage_number]
838
            weak_classifiers_amount = int(current_stage.find('maxWeakCount').text)
839
            stumps_number += weak_classifiers_amount
840

841
        # Allocate memory for data.
842
        features_carr = <MBLBP*>malloc(features_number * sizeof(MBLBP))
843
        stumps_carr = <MBLBPStump*>malloc(stumps_number * sizeof(MBLBPStump))
844
        stages_carr = <Stage*>malloc(stages_number*sizeof(Stage))
845
        # Each look-up table consists of 8 uint32 values.
846
        LUTs_carr = <cnp.uint32_t*>malloc(8 * stumps_number *
847
                                          sizeof(cnp.uint32_t))
848

849
        # Check if memory was allocated.
850
        if not (features_carr and stumps_carr and stages_carr and LUTs_carr):
851
            free(features_carr)
852
            free(stumps_carr)
853
            free(stages_carr)
854
            free(LUTs_carr)
855
            raise MemoryError("Failed to allocate memory while parsing XML.")
856

857
        # Parse and load features in memory.
858
        for feature_number in range(features_number):
859
            params = features[feature_number][0].text.split()
860
            # Wrap in list() because map() is lazy in Python 3.
861
            params = list(map(lambda x: int(x), params))
862
            new_feature = MBLBP(params[1], params[0], params[2], params[3])
863
            features_carr[feature_number] = new_feature
864

865
        stump_lut_idx = 0
866
        stump_idx = 0
867

868
        # Parse and load stumps, stages.
869
        for stage_number in range(stages_number):
870

871
            current_stage = stages[stage_number]
872

873
            # Parse and load current stage.
874
            stage_threshold = float(current_stage.find('stageThreshold').text)
875
            weak_classifiers_amount = int(current_stage.find('maxWeakCount').text)
876
            new_stage = Stage(stump_idx, weak_classifiers_amount,
877
                              stage_threshold)
878
            stages_carr[stage_number] = new_stage
879

880
            weak_classifiers = current_stage.find('weakClassifiers')
881

882
            for weak_classifier_number in range(weak_classifiers_amount):
883

884
                current_weak_classifier = weak_classifiers[weak_classifier_number]
885

886
                # Stump's leaf values. First negative if image is probably not
887
                # a face. Second positive if image is probably a face.
888
                leaf_values = current_weak_classifier.find('leafValues').text
889
                # Wrap in list() because map() is lazy in Python 3.
890
                leaf_values = list(map(lambda x: float(x), leaf_values.split()))
891

892
                # Skip the first two entries of the internal node; only the
893
                # feature index and the LUT values that follow are used.
894
                internal_nodes = current_weak_classifier.find('internalNodes')
895
                internal_nodes = internal_nodes.text.split()[2:]
896

897
                # Extract the feature number and respective parameters.
898
                # The MBLBP position and size.
899
                feature_number = int(internal_nodes[0])
900
                # Wrap in list() because map() is lazy in Python 3.
901
                lut_array = list(map(lambda x: int(x), internal_nodes[1:]))
902
                # Cast via astype to avoid warning about integer wraparound.
903
                # see: https://github.com/scikit-image/scikit-image/issues/6638
904
                lut = np.asarray(lut_array).astype(np.uint32)
905

906
                # Copy array to the main LUT array
907
                for i in range(8):
908
                    LUTs_carr[stump_lut_idx + i] = lut[i]
909

910
                new_stump = MBLBPStump(feature_number, stump_lut_idx,
911
                                       leaf_values[0], leaf_values[1])
912
                stumps_carr[stump_idx] = new_stump
913

914
                stump_lut_idx += 8
915
                stump_idx += 1
916

917
        self.eps = eps
918
        self.window_height = window_height
919
        self.window_width = window_width
920
        self.features = features_carr
921
        self.stumps = stumps_carr
922
        self.stages = stages_carr
923
        self.LUTs = LUTs_carr
924
        self.stages_number = stages_number
925
        self.features_number = features_number
926
        self.stumps_number = stumps_number
927

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.