# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings

import numpy as np
import paddle
from paddle.metric import Accuracy, Metric, Precision, Recall

__all__ = ["Accuracy", "AccuracyAndF1", "Mcc", "PearsonAndSpearman", "MultiLabelsMetric"]
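# Note: Accuracy is imported from paddle.metric and re-exported through
# __all__, presumably so that all of the metrics used here can be imported
# from this single module.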


class AccuracyAndF1(Metric):
    """
    This class encapsulates Accuracy, Precision, Recall and F1 metric logic,
    and the `accumulate` function returns the accuracy, precision, recall and f1.
    An overview of all metrics can be found in the documentation of `paddle.metric
    <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/metric/Overview_cn.html>`_.

    Args:
        topk (int or tuple(int), optional):
            Number of top elements to look at for computing accuracy.
            Defaults to (1,).
        pos_label (int, optional):
            The positive label for calculating precision and recall.
            Defaults to 1.
        name (str, optional):
            String name of the metric instance. Defaults to 'acc_and_f1'.

    Example:

        .. code-block::

            import paddle
            from paddlenlp.metrics import AccuracyAndF1

            x = paddle.to_tensor([[0.1, 0.9], [0.5, 0.5], [0.6, 0.4], [0.7, 0.3]])
            y = paddle.to_tensor([[1], [0], [1], [1]])

            m = AccuracyAndF1()
            correct = m.compute(x, y)
            m.update(correct)
            res = m.accumulate()
            print(res) # (0.5, 0.5, 0.3333333333333333, 0.4, 0.45)

    """

    def __init__(self, topk=(1,), pos_label=1, name="acc_and_f1", *args, **kwargs):
        super(AccuracyAndF1, self).__init__(*args, **kwargs)
        self.topk = topk
        self.pos_label = pos_label
        self._name = name
        self.acc = Accuracy(self.topk, *args, **kwargs)
        self.precision = Precision(*args, **kwargs)
        self.recall = Recall(*args, **kwargs)
        self.reset()

    def compute(self, pred, label, *args):
        """
        Accepts the network's output and the labels, and calculates the top-k
        (maximum value in `topk`) indices for accuracy.

        Args:
            pred (Tensor):
                Predicted tensor with dtype float32 or float64 and shape
                [batch_size, num_classes].
            label (Tensor):
                The ground truth tensor with dtype int64 and shape
                [batch_size, 1], or [batch_size, num_classes] in one-hot
                representation.

        Returns:
            Tensor: Correct mask; each element indicates whether the prediction
            equals the label. It's a tensor with dtype float32 and shape
            [batch_size, topk].

        """
        self.label = label
        self.preds_pos = paddle.nn.functional.softmax(pred)[:, self.pos_label]
        return self.acc.compute(pred, label)

    def update(self, correct, *args):
        """
        Updates the metric states (accuracy, precision and recall) in order to
        calculate the accumulated accuracy, precision and recall of all instances.

        Args:
            correct (Tensor):
                Correct mask for calculating accuracy; a tensor with shape
                [batch_size, topk] and dtype float32.

        """
        self.acc.update(correct)
        self.precision.update(self.preds_pos, self.label)
        self.recall.update(self.preds_pos, self.label)

    def accumulate(self):
        """
        Calculates and returns the accumulated metric.

        Returns:
            tuple: The accumulated metric, a tuple of (acc, precision,
            recall, f1, average_of_acc_and_f1).

            With the fields:

            - `acc` (numpy.float64):
                The accumulated accuracy.
            - `precision` (numpy.float64):
                The accumulated precision.
            - `recall` (numpy.float64):
                The accumulated recall.
            - `f1` (numpy.float64):
                The accumulated f1.
            - `average_of_acc_and_f1` (numpy.float64):
                The average of the accumulated accuracy and f1.

        """
        acc = self.acc.accumulate()
        precision = self.precision.accumulate()
        recall = self.recall.accumulate()
        if precision == 0.0 or recall == 0.0:
            f1 = 0.0
        else:
            # F1 is the harmonic mean: 1/f1 = 1/2 * (1/precision + 1/recall)
            f1 = (2 * precision * recall) / (precision + recall)
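        # The last element of the returned tuple averages accuracy and F1;
        # reporting this mean is the usual convention for GLUE-style tasks
        # such as MRPC and QQP.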
        return (
            acc,
            precision,
            recall,
            f1,
            (acc + f1) / 2,
        )

    def reset(self):
        """
        Resets all metric states.
        """
        self.acc.reset()
        self.precision.reset()
        self.recall.reset()
        self.label = None
        self.preds_pos = None

    def name(self):
        """
        Returns the name of the metric instance.

        Returns:
            str: The name of the metric instance.

        """
        return self._name


class Mcc(Metric):
    """
    This class calculates the `Matthews correlation coefficient <https://en.wikipedia.org/wiki/Matthews_correlation_coefficient>`_ .

    Args:
        name (str, optional):
            String name of the metric instance. Defaults to 'mcc'.

    Example:

        .. code-block::

            import paddle
            from paddlenlp.metrics import Mcc

            x = paddle.to_tensor([[-0.1, 0.12], [-0.23, 0.23], [-0.32, 0.21], [-0.13, 0.23]])
            y = paddle.to_tensor([[1], [0], [1], [1]])

            m = Mcc()
            (preds, label) = m.compute(x, y)
            m.update((preds, label))
            res = m.accumulate()
            print(res) # (0.0,)

    """

    def __init__(self, name="mcc", *args, **kwargs):
        super(Mcc, self).__init__(*args, **kwargs)
        self._name = name
        self.tp = 0  # true positive
        self.fp = 0  # false positive
        self.tn = 0  # true negative
        self.fn = 0  # false negative

    def compute(self, pred, label, *args):
        """
        Processes the pred tensor, and returns the index of the maximum of each
        sample.

        Args:
            pred (Tensor):
                The predicted value, a Tensor with dtype float32 or float64
                and shape [batch_size, num_classes].
            label (Tensor):
                The ground truth value, a Tensor with dtype int64 and shape
                [batch_size, 1].

        Returns:
            tuple: A tuple of preds and label. Each has shape
            [batch_size, 1]; preds holds the int64 index of each sample's
            maximum, and label is returned unchanged.

        """
        preds = paddle.argsort(pred, descending=True)[:, :1]
        return (preds, label)

    def update(self, preds_and_labels):
        """
        Calculates the states, i.e. the number of true positive, false positive,
        true negative and false negative samples.

        Args:
            preds_and_labels (tuple[Tensor]):
                Tuple of the predicted value and the ground truth label, each
                with shape [batch_size, 1].

        """
        preds = preds_and_labels[0]
        labels = preds_and_labels[1]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        if isinstance(labels, paddle.Tensor):
            labels = labels.numpy().reshape(-1, 1)
        sample_num = labels.shape[0]
        for i in range(sample_num):
            pred = preds[i]
            label = labels[i]
            if pred == 1:
                if pred == label:
                    self.tp += 1
                else:
                    self.fp += 1
            else:
                if pred == label:
                    self.tn += 1
                else:
                    self.fn += 1

    def accumulate(self):
        """
        Calculates and returns the accumulated metric.

        Returns:
            tuple: The accumulated metric, a tuple (mcc,), where `mcc` is the
            accumulated Matthews correlation coefficient with data type float64.

        """
        if self.tp == 0 or self.fp == 0 or self.tn == 0 or self.fn == 0:
            mcc = 0.0
        else:
            # mcc = (tp*tn - fp*fn) / sqrt((tp+fp) * (tp+fn) * (tn+fp) * (tn+fn))
            mcc = (self.tp * self.tn - self.fp * self.fn) / math.sqrt(
                (self.tp + self.fp) * (self.tp + self.fn) * (self.tn + self.fp) * (self.tn + self.fn)
            )
        return (mcc,)

    def reset(self):
        """
        Resets all metric states.
        """
        self.tp = 0  # true positive
        self.fp = 0  # false positive
        self.tn = 0  # true negative
        self.fn = 0  # false negative

    def name(self):
        """
        Returns the name of the metric instance.

        Returns:
            str: The name of the metric instance.

        """
        return self._name


class PearsonAndSpearman(Metric):
    """
    This class calculates the `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_
    and `Spearman's rank correlation coefficient <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_ .

    Args:
        name (str, optional):
            String name of the metric instance. Defaults to 'pearson_and_spearman'.

    Example:

        .. code-block::

            import paddle
            from paddlenlp.metrics import PearsonAndSpearman

            x = paddle.to_tensor([[0.1], [1.0], [2.4], [0.9]])
            y = paddle.to_tensor([[0.0], [1.0], [2.9], [1.0]])

            m = PearsonAndSpearman()
            m.update((x, y))
            res = m.accumulate()
            print(res) # (0.9985229081857804, 1.0, 0.9992614540928901)

    """

    def __init__(self, name="pearson_and_spearman", *args, **kwargs):
        super(PearsonAndSpearman, self).__init__(*args, **kwargs)
        self._name = name
        self.preds = []
        self.labels = []

    def update(self, preds_and_labels):
        """
        Ensures that preds and labels are numpy.ndarray, flattens them, and
        appends them to the internal state.

        Args:
            preds_and_labels (tuple[Tensor] or list[Tensor]):
                Tuple or list of the predicted value and the ground truth label.
                Their data type should be float32 or float64, and their shape is
                [batch_size, d0, ..., dN].

        """
        preds = preds_and_labels[0]
        labels = preds_and_labels[1]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        if isinstance(labels, paddle.Tensor):
            labels = labels.numpy()
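        # Note: np.squeeze turns a single-element array into a 0-d array, and
        # its .tolist() then returns a plain scalar rather than a list, so a
        # batch of size 1 would append a scalar here.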
        preds = np.squeeze(preds.reshape(-1, 1)).tolist()
        labels = np.squeeze(labels.reshape(-1, 1)).tolist()
        self.preds.append(preds)
        self.labels.append(labels)

    def accumulate(self):
        """
        Calculates and returns the accumulated metric.

        Returns:
            tuple: The accumulated metric, a tuple of (pearson, spearman,
            the_average_of_pearson_and_spearman).

            With the fields:

            - `pearson` (numpy.float64):
                The accumulated Pearson correlation coefficient.

            - `spearman` (numpy.float64):
                The accumulated Spearman's rank correlation coefficient.

            - `the_average_of_pearson_and_spearman` (numpy.float64):
                The average of the accumulated Pearson and Spearman correlation
                coefficients.

        """
        preds = [item for sublist in self.preds for item in sublist]
        labels = [item for sublist in self.labels for item in sublist]
        pearson = self.pearson(preds, labels)
        spearman = self.spearman(preds, labels)
        return (
            pearson,
            spearman,
            (pearson + spearman) / 2,
        )

    def pearson(self, preds, labels):
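        # Pearson's r via the single-pass computational formula:
        #   r = (sum(x*y) - sum(x)*sum(y)/n)
        #       / sqrt((sum(x^2) - sum(x)^2/n) * (sum(y^2) - sum(y)^2/n))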
        n = len(preds)
        # simple sums
        sum1 = sum(float(preds[i]) for i in range(n))
        sum2 = sum(float(labels[i]) for i in range(n))
        # sum up the squares
        sum1_pow = sum(pow(v, 2.0) for v in preds)
        sum2_pow = sum(pow(v, 2.0) for v in labels)
        # sum up the products
        p_sum = sum(preds[i] * labels[i] for i in range(n))

        numerator = p_sum - (sum1 * sum2 / n)
        denominator = math.sqrt((sum1_pow - pow(sum1, 2) / n) * (sum2_pow - pow(sum2, 2) / n))
        if denominator == 0:
            return 0.0
        return numerator / denominator

    def spearman(self, preds, labels):
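        # Spearman's rho via the rank-difference formula
        #   rho = 1 - 6 * sum(d_i^2) / (n * (n^2 - 1)),
        # which is exact only when there are no ties; get_rank below breaks
        # ties by sort order instead of assigning average ranks.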
        preds_rank = self.get_rank(preds)
        labels_rank = self.get_rank(labels)

        total = 0
        n = len(preds)
        for i in range(n):
            total += pow((preds_rank[i] - labels_rank[i]), 2)
        spearman = 1 - float(6 * total) / (n * (pow(n, 2) - 1))
        return spearman

    def get_rank(self, raw_list):
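        # Ranks are 1-based and assigned in descending order of value; equal
        # values receive distinct consecutive ranks in argsort order.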
        x = np.array(raw_list)
        r_x = np.empty(x.shape, dtype=int)
        y = np.argsort(-x)
        for i, k in enumerate(y):
            r_x[k] = i + 1
        return r_x

    def reset(self):
        """
        Resets all metric states.
        """
        self.preds = []
        self.labels = []

    def name(self):
        """
        Returns the name of the metric instance.

        Returns:
            str: The name of the metric instance.

        """
        return self._name


class MultiLabelsMetric(Metric):
    """
    This class encapsulates Accuracy, Precision, Recall and F1 metric logic in
    the multi-label setting (and also the binary setting).
    Some code is taken and modified from sklearn.metrics .

    Args:
        num_labels (int):
            The total number of labels, which is usually the number of classes.
        name (str, optional):
            String name of the metric instance. Defaults to 'multi_labels_metric'.

    Example:

        .. code-block::

            import paddle
            from paddlenlp.metrics import MultiLabelsMetric

            x = paddle.to_tensor([[0.1, 0.2, 0.9], [0.5, 0.8, 0.5], [0.6, 1.5, 0.4], [2.8, 0.7, 0.3]])
            y = paddle.to_tensor([[2], [1], [2], [1]])

            m = MultiLabelsMetric(num_labels=3)
            args = m.compute(x, y)
            m.update(args)

            result1 = m.accumulate(average=None)
            # (array([0.0, 0.5, 1.0]), array([0.0, 0.5, 0.5]), array([0.0, 0.5, 0.66666667]))
            result2 = m.accumulate(average='binary', pos_label=0)
            # (0.0, 0.0, 0.0)
            result3 = m.accumulate(average='binary', pos_label=1)
            # (0.5, 0.5, 0.5)
            result4 = m.accumulate(average='binary', pos_label=2)
            # (1.0, 0.5, 0.6666666666666666)
            result5 = m.accumulate(average='micro')
            # (0.5, 0.5, 0.5)
            result6 = m.accumulate(average='macro')
            # (0.5, 0.3333333333333333, 0.38888888888888884)
            result7 = m.accumulate(average='weighted')
            # (0.75, 0.5, 0.5833333333333333)

    Note: When zero division is encountered (details as follows), the corresponding metric will be set to 0.0:
        precision is zero_division if there are no positive predictions
        recall is zero_division if there are no positive labels
        fscore is zero_division if all labels AND predictions are negative
    """

    def __init__(self, num_labels, name="multi_labels_metric"):
        super(MultiLabelsMetric, self).__init__()
        if num_labels <= 1:
            raise ValueError(f"The num_labels is {num_labels}, which must be greater than 1.")
        self.num_labels = num_labels
        self._name = name
        self._confusion_matrix = np.zeros((num_labels, 2, 2), dtype=int)

    def update(self, args):
        """
        Updates the metric state (a per-label confusion matrix) in order to
        calculate the accumulated precision, recall and f1 of all instances.

        Args:
            args (tuple of Tensor):
                The tuple returned from the `compute` function.
        """
        pred = args[0].numpy()
        label = args[1].numpy()
        tmp_confusion_matrix = self._multi_labels_confusion_matrix(pred, label)
        self._confusion_matrix += tmp_confusion_matrix

    def accumulate(self, average=None, pos_label=1):
        """
        Calculates and returns the accumulated metric.

        Args:
            average (str in {'binary', 'micro', 'macro', 'weighted'} or None, optional):
                Defaults to `None`. If `None`, the scores for each class are returned.
                Otherwise, this determines the type of averaging performed on the data:

                - `binary` :
                    Only report results for the class specified by pos_label.

                - `micro` :
                    Calculate metrics globally by counting the total true positives,
                    false negatives and false positives.

                - `macro` :
                    Calculate metrics for each label, and find their unweighted mean.
                    This does not take label imbalance into account.

                - `weighted` :
                    Calculate metrics for each label, and find their average weighted
                    by support (the number of true instances for each label). This
                    alters `macro` to account for label imbalance; it can result in
                    an F-score that is not between precision and recall.

            pos_label (int, optional):
                The positive label for calculating precision and recall in the binary
                setting. Note: this argument is used only when `average='binary'`;
                otherwise, it is ignored.
                Defaults to 1.

        Returns:
            tuple: The accumulated metric, a tuple of (precision, recall, f1).

                With the fields:

                - `precision` (numpy.float64 or numpy.ndarray if average=None):
                    The accumulated precision.
                - `recall` (numpy.float64 or numpy.ndarray if average=None):
                    The accumulated recall.
                - `f1` (numpy.float64 or numpy.ndarray if average=None):
                    The accumulated f1.

        """
        if average not in {"binary", "micro", "macro", "weighted", None}:
            raise ValueError(f"The average is {average}, which is unknown.")
        if average == "binary":
            if pos_label >= self.num_labels:
                raise ValueError(
                    f"The pos_label is {pos_label}, num_labels is {self.num_labels}. "
                    f"The num_labels must be greater than pos_label."
                )

        confusion_matrix = None  # [*, 2, 2]
        if average == "binary":
            confusion_matrix = np.expand_dims(self._confusion_matrix[pos_label], axis=0)
        elif average == "micro":
            confusion_matrix = self._confusion_matrix.sum(axis=0, keepdims=True)
        # if average is 'macro', 'weighted' or None
        else:
            confusion_matrix = self._confusion_matrix

        tp = confusion_matrix[:, 1, 1]  # [*,]
        pred = tp + confusion_matrix[:, 0, 1]  # [*,]
        true = tp + confusion_matrix[:, 1, 0]  # [*,]

        def _robust_divide(numerator, denominator, metric_name):
            mask = denominator == 0.0
            denominator = denominator.copy()
            denominator[mask] = 1  # avoid zero division
            result = numerator / denominator

            if not np.any(mask):
                return result

            # precision is zero_division if there are no positive predictions
            # recall is zero_division if there are no positive labels
            # fscore is zero_division if all labels AND predictions are negative
            warnings.warn(f"Zero division when calculating {metric_name}.", UserWarning)
            result[mask] = 0.0
            return result

        precision = _robust_divide(tp, pred, "precision")
        recall = _robust_divide(tp, true, "recall")
        f1 = _robust_divide(2 * (precision * recall), (precision + recall), "f1")

        weights = None  # [num_labels]
        if average == "weighted":
            weights = true
            if weights.sum() == 0:
                zero_division_value = np.float64(0.0)
                if pred.sum() == 0:
                    return (zero_division_value, zero_division_value, zero_division_value)
                else:
                    return (np.float64(0.0), zero_division_value, np.float64(0.0))
        elif average == "macro":
            weights = np.ones((self.num_labels), dtype=float)
        if average is not None:
            precision = np.average(precision, weights=weights)
            recall = np.average(recall, weights=weights)
            f1 = np.average(f1, weights=weights)

        return precision, recall, f1

    def compute(self, pred, label):
        """
        Accepts the network's output and the labels, flattens them, and
        calculates the predicted label of each instance as the argmax over
        the label dimension.

        Args:
            pred (Tensor):
                Predicted tensor with dtype float32 or float64 and shape
                [batch_size, *, num_labels].
            label (Tensor):
                The ground truth tensor with dtype int64 and shape
                [batch_size, *], or [batch_size, *, num_labels] in one-hot
                representation.

        Returns:
            tuple of Tensor: two 1-D Tensors holding the flattened predicted
            and ground truth labels. The tuple should be passed to the
            `update` function.
        """
        if not (paddle.is_tensor(pred) and paddle.is_tensor(label)):
            raise ValueError("pred and label must be paddle tensor")

        if pred.shape[-1] != self.num_labels:
            raise ValueError(f"The last dim of pred is {pred.shape[-1]}, which should be num_labels")
        pred = paddle.reshape(pred, [-1, self.num_labels])
        pred = paddle.argmax(pred, axis=-1)

        if label.shape[-1] == self.num_labels:
            label = paddle.reshape(label, [-1, self.num_labels])
            label = paddle.argmax(label, axis=-1)
        else:
            label = paddle.reshape(label, [-1])
            if paddle.max(label) >= self.num_labels:
                raise ValueError(f"Tensor label has value {paddle.max(label)}, which is not less than num_labels")

        if pred.shape[0] != label.shape[0]:
            raise ValueError("The length of pred is not equal to the length of label")

        return pred, label

    def _multi_labels_confusion_matrix(self, pred, label):
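        # Builds a one-vs-rest 2x2 confusion matrix for every label, laid out
        # as [[tn, fp], [fn, tp]] (the same layout as
        # sklearn.metrics.multilabel_confusion_matrix).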
        tp_bins = label[pred == label]
        tp = np.bincount(tp_bins, minlength=self.num_labels)  # [num_labels,]
        tp_plus_fp = np.bincount(pred, minlength=self.num_labels)  # [num_labels,]
        tp_plus_fn = np.bincount(label, minlength=self.num_labels)  # [num_labels,]
        fp = tp_plus_fp - tp  # [num_labels,]
        fn = tp_plus_fn - tp  # [num_labels,]
        tn = pred.shape[0] - tp - fp - fn  # [num_labels,]
        return np.array([tn, fp, fn, tp]).T.reshape(-1, 2, 2)  # [num_labels, 2, 2]

    def reset(self):
        """
        Resets all metric states.
        """
        self._confusion_matrix = np.zeros((self.num_labels, 2, 2), dtype=int)

    def name(self):
        """
        Returns the name of the metric instance.

        Returns:
            str: The name of the metric instance.

        """
        return self._name
