gpt-neox / common.py

# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import shutil
import itertools
import re  # used below by set_accelerator_visible()
import inspect  # used below by DistributedExec._get_fixture_kwargs()
import subprocess  # used below by set_accelerator_visible() for device discovery
from pathlib import Path
from abc import ABC, abstractmethod
from deepspeed.accelerator import get_accelerator

import pytest
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker
import random
import train

import torch

import torch.distributed as dist
from torch.multiprocessing import Process
import torch.multiprocessing as mp
from yaml import load

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
from copy import deepcopy
import deepspeed

TEST_CHECKPOINT_DIR = "test_checkpoint"
TEST_LOG_DIR = "test_logs"
TEST_TENSORBOARD_DIR = "test_tensorboard"

# Worker timeout *after* the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
DEEPSPEED_TEST_TIMEOUT = 600


def get_xdist_worker_id():
    xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace("gw", "")
        return int(xdist_worker_id)
    return None


def get_master_port():
    master_port = os.environ.get("DS_TEST_PORT", "29503")
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        master_port = str(int(master_port) + xdist_worker_id)
    return master_port
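

# Worked example (illustrative): with the default base port "29503" and pytest-xdist
# workers gw0 / gw1 / gw2, get_master_port() returns "29503", "29504", and "29505"
# respectively, so concurrent xdist workers never share a rendezvous port.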


_num_gpus = None
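

# set_accelerator_visible() below calls is_rocm_pytorch(), which is neither defined nor
# imported in this file. A minimal stand-in sketch, assuming a ROCm build of PyTorch is
# detected via torch.version.hip (None on CUDA builds, a version string on ROCm builds):
def is_rocm_pytorch():
    return hasattr(torch.version, "hip") and torch.version.hip is not None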


def set_accelerator_visible():
    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is None:
        xdist_worker_id = 0
    if cuda_visible is None:
        # CUDA_VISIBLE_DEVICES is not set, discover it using accelerator specific command instead
        if get_accelerator().device_name() == "cuda":
            if is_rocm_pytorch():
                rocm_smi = subprocess.check_output(["rocm-smi", "--showid"])
                gpu_ids = filter(
                    lambda s: "GPU" in s, rocm_smi.decode("utf-8").strip().split("\n")
                )
                num_accelerators = len(list(gpu_ids))
            else:
                nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"])
                num_accelerators = len(nvidia_smi.decode("utf-8").strip().split("\n"))
        elif get_accelerator().device_name() == "xpu":
            clinfo = subprocess.check_output(["clinfo"])
            lines = clinfo.decode("utf-8").strip().split("\n")
            num_accelerators = 0
            for line in lines:
                match = re.search("Device Type.*GPU", line)
                if match:
                    num_accelerators += 1
        elif get_accelerator().device_name() == "npu":
            npu_smi = subprocess.check_output(["npu-smi", "info", "-l"])
            num_accelerators = int(
                npu_smi.decode("utf-8").strip().split("\n")[0].split(":")[1].strip()
            )
        else:
            assert get_accelerator().device_name() == "cpu"
            cpu_sockets = int(
                subprocess.check_output(
                    'cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l',
                    shell=True,
                )
            )
            num_accelerators = cpu_sockets

        cuda_visible = ",".join(map(str, range(num_accelerators)))

    # rotate list based on xdist worker id, example below
    # wid=0 -> ['0', '1', '2', '3']
    # wid=1 -> ['1', '2', '3', '0']
    # wid=2 -> ['2', '3', '0', '1']
    # wid=3 -> ['3', '0', '1', '2']
    dev_id_list = cuda_visible.split(",")
    dev_id_list = dev_id_list[xdist_worker_id:] + dev_id_list[:xdist_worker_id]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)


def count_gpus():
    global _num_gpus
    if _num_gpus is None:
        import subprocess

        nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"])
        _num_gpus = len(nvidia_smi.decode("utf-8").strip().split("\n"))
    return _num_gpus


def set_cuda_visibile():
    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is None:
        xdist_worker_id = 0
    if cuda_visible is None:
        # CUDA_VISIBLE_DEVICES is not set, discover it from nvidia-smi instead
        import subprocess

        nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"])
        num_gpus = len(nvidia_smi.decode("utf-8").strip().split("\n"))
        cuda_visible = ",".join(map(str, range(num_gpus)))

    # rotate list based on xdist worker id, example below
    # wid=0 -> ['0', '1', '2', '3']
    # wid=1 -> ['1', '2', '3', '0']
    # wid=2 -> ['2', '3', '0', '1']
    # wid=3 -> ['3', '0', '1', '2']
    dev_id_list = cuda_visible.split(",")
    dev_id_list = dev_id_list[xdist_worker_id:] + dev_id_list[:xdist_worker_id]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)


def get_root_directory():
    return Path(__file__).parents[1]


def get_config_directory():
    return get_root_directory() / "configs"


def get_configs_with_path(configs):
    return [str(get_config_directory() / cfg) for cfg in configs]
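
# Illustrative only: get_configs_with_path(["125M.yml", "local_setup.yml"]) resolves each
# name against <repo root>/configs/ and returns the full paths as strings; the file names
# here are placeholders for whichever configs a test needs.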


def clear_test_dirs():
    log_dir = os.path.join(get_root_directory(), TEST_LOG_DIR)
    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)

    checkpoint_dir = os.path.join(get_root_directory(), TEST_CHECKPOINT_DIR)
    if os.path.isdir(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)

    tensorboard_dir = os.path.join(get_root_directory(), TEST_TENSORBOARD_DIR)
    if os.path.isdir(tensorboard_dir):
        shutil.rmtree(tensorboard_dir)


class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture.
    """

    world_size = 2
    backend = get_accelerator().communication_backend_name()
    init_distributed = True
    set_dist_env = True
    requires_cuda_env = True
    reuse_dist_env = False
    _pool_cache = {}
    exec_timeout = DEEPSPEED_TEST_TIMEOUT

    @abstractmethod
    def run(self):
        ...

    def __call__(self, request=None):
        self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
        world_size = self.world_size
        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)

    def _get_fixture_kwargs(self, request, func):
        if not request:
            return {}
        # Grab fixture / parametrize kwargs from pytest request object
        fixture_kwargs = {}
        params = inspect.getfullargspec(func).args
        params.remove("self")
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                pass  # test methods can have kwargs that are not fixtures
        return fixture_kwargs

    def _launch_procs(self, num_procs):
        # Verify we have enough accelerator devices to run this test
        if (
            get_accelerator().is_available()
            and get_accelerator().device_count() < num_procs
        ):
            pytest.skip(
                f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
            )

        mp.set_start_method("spawn", force=True)

        # Create process pool or use cached one
        master_port = None
        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                master_port = get_master_port()
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)
            master_port = get_master_port()

        # Run the test
        args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        try:
            skip_msgs = skip_msgs_async.get(self.exec_timeout)
        except mp.TimeoutError:
            # Shortcut to exit pytest in the case of a hung test. This usually
            # means an environment error, and the rest of the tests will also
            # hang (causing super long unit test runtimes)
            pytest.exit("Test hung, exiting", returncode=0)

        # Tear down distributed environment and close process pools
        self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
            assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
            pytest.skip(skip_msgs[0])

    def _dist_run(self, local_rank, num_procs, master_port):
        skip_msg = ""
        if not dist.is_initialized():
            """Initialize deepspeed.comm and execute the user function."""
            if self.set_dist_env:
                os.environ["MASTER_ADDR"] = "127.0.0.1"
                os.environ["MASTER_PORT"] = str(master_port)
                os.environ["LOCAL_RANK"] = str(local_rank)
                # NOTE: unit tests don't support multi-node so local_rank == global rank
                os.environ["RANK"] = str(local_rank)
                # In case of multiprocess launching, LOCAL_SIZE should be the same as WORLD_SIZE
                # (the DeepSpeed single-node launcher would also set LOCAL_SIZE accordingly)
                os.environ["LOCAL_SIZE"] = str(num_procs)
                os.environ["WORLD_SIZE"] = str(num_procs)

            # turn off NCCL logging if set
            os.environ.pop("NCCL_DEBUG", None)

            if get_accelerator().is_available():
                set_accelerator_visible()

            if get_accelerator().is_available():
                get_accelerator().set_device(local_rank)

            if self.init_distributed:
                deepspeed.init_distributed(dist_backend=self.backend)
                dist.barrier()

        try:
            self.run(**self._fixture_kwargs)
        except BaseException as e:
            if isinstance(e, Skipped):
                skip_msg = e.msg
            else:
                raise e

        return skip_msg

    def _dist_destroy(self):
        if (dist is not None) and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False):
        if force or not self.reuse_dist_env:
            msg = pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
            pool.close()
            pool.join()


class DistributedFixture(DistributedExec):
    """
    Implementation that extends @pytest.fixture to allow for distributed execution.
    This is primarily meant to be used when a test requires executing two pieces of
    code with different world sizes.

    There are 2 parameters that can be modified:
        - world_size: int = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside fixture
        - can be reused by multiple tests
        - can accept other fixtures as input

    Limitations:
        - cannot use @pytest.mark.parametrize
        - world_size cannot be modified after definition, and only a single world_size value is accepted
        - any fixtures used must also be used in the test that uses this fixture (see example below)
        - cannot return values; data can be passed to a DistributedTest object by
          writing to class_tmpdir (see example below)

    Usage:
        - must implement a run(self, ...) method
        - the fixture is used by adding the class name as an argument to a test function

    Example:
        @pytest.fixture(params=[10,20])
        def regular_pytest_fixture(request):
            return request.param

        class distributed_fixture_example(DistributedFixture):
            world_size = 4

            def run(self, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                local_rank = os.environ["LOCAL_RANK"]
                print(f"Rank {local_rank} with value {regular_pytest_fixture}")
                with open(os.path.join(class_tmpdir, f"{local_rank}.txt"), "w") as f:
                    f.write(f"{local_rank},{regular_pytest_fixture}")

        class TestExample(DistributedTest):
            world_size = 1

            def test(self, distributed_fixture_example, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                for rank in range(4):
                    with open(os.path.join(class_tmpdir, f"{rank}.txt"), "r") as f:
                        assert f.read() == f"{rank},{regular_pytest_fixture}"
    """

    is_dist_fixture = True

    # These values are just placeholders so that pytest recognizes this as a fixture
    _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None)
    __name__ = ""

    def __init__(self):
        assert isinstance(
            self.world_size, int
        ), "Only one world size is allowed for distributed fixtures"
        self.__name__ = type(self).__name__
        _pytestfixturefunction = FixtureFunctionMarker(
            scope="function", params=None, name=self.__name__
        )


class DistributedTest(DistributedExec):
    """
    Implementation for running pytest with distributed execution.

    There are 2 parameters that can be modified:
        - world_size: Union[int,List[int]] = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside tests
        - works with pytest fixtures, parametrize, mark, etc.
        - can contain multiple tests (each of which can be parametrized separately)
        - class methods can be fixtures (usable by tests in this class only)
        - world_size can be changed for individual tests using @pytest.mark.world_size(world_size)
        - class_tmpdir is a fixture that can be used to get a tmpdir shared among
          all tests (including DistributedFixture)

    Usage:
        - class name must start with "Test"
        - must implement one or more test*(self, ...) methods

    Example:
        @pytest.fixture(params=[10,20])
        def val1(request):
            return request.param

        @pytest.mark.fast
        @pytest.mark.parametrize("val2", [30,40])
        class TestExample(DistributedTest):
            world_size = 2

            @pytest.fixture(params=[50,60])
            def val3(self, request):
                return request.param

            def test_1(self, val1, val2, str1="hello world"):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                assert all((val1, val2, str1))

            @pytest.mark.world_size(1)
            @pytest.mark.parametrize("val4", [70,80])
            def test_2(self, val1, val2, val3, val4):
                assert int(os.environ["WORLD_SIZE"]) == 1
                assert all((val1, val2, val3, val4))
    """

    def __init__(self):
        self.is_dist_test = True

    # Temporary directory that is shared among test methods in a class
    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        fn = tmpdir_factory.mktemp(self.__class__.__name__)
        return fn

    def run(self, **fixture_kwargs):
        self._current_test(**fixture_kwargs)

    def __call__(self, request):
        self._current_test = self._get_current_test_func(request)
        self._fixture_kwargs = self._get_fixture_kwargs(request, self._current_test)

        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        # Catch world_size override pytest mark
        for mark in getattr(request.function, "pytestmark", []):
            if mark.name == "world_size":
                world_size = mark.args[0]
                break
        else:
            world_size = self.world_size

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)
            time.sleep(0.5)

    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)


def get_test_path(filename):
    curr_path = Path(__file__).parent
    return str(curr_path.joinpath(filename))


def model_setup(yaml_list=None, param_dict=None, clear_data=True):
    from megatron.neox_arguments import NeoXArgs
    from megatron.mpu import destroy_model_parallel
    from megatron import initialize_megatron
    from megatron.training import setup_model_and_optimizer

    destroy_model_parallel()  # mpu model parallel contains remaining global vars
    if clear_data and (
        not torch.distributed.is_initialized()
        or torch.distributed.get_world_size() == 1
        or torch.distributed.get_rank() == 0
    ):
        clear_test_dirs()

    overwrite_values = {
        "user_script": str(get_root_directory() / "train.py"),
        "save": TEST_CHECKPOINT_DIR,
        "load": TEST_CHECKPOINT_DIR,
        "log_dir": TEST_LOG_DIR,
        "tensorboard_dir": TEST_TENSORBOARD_DIR,
    }

    # yaml_list and param_dict should not both be None
    assert yaml_list is not None or param_dict is not None

    # initially load config from files as would be the case in deepy.py
    if yaml_list is not None:
        args_loaded = NeoXArgs.from_ymls(yaml_list, overwrite_values=overwrite_values)
    else:
        p_dict = param_dict.copy()
        p_dict.update(overwrite_values)
        args_loaded = NeoXArgs.from_dict(p_dict)

    args_loaded.build_tokenizer()

    initialize_megatron(neox_args=args_loaded)
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        neox_args=args_loaded, use_cache=True
    )
    return model, optimizer, lr_scheduler, args_loaded
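
# Illustrative sketch of how model_setup() is typically driven (not executed here); the
# param dict is a placeholder for a small NeoX config, such as one produced by parametrize():
#
#     model, optimizer, lr_scheduler, args_loaded = model_setup(
#         yaml_list=None, param_dict=small_param_dict, clear_data=True
#     )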


def simulate_deepy_env(monkeypatch, input_args):
    from megatron.neox_arguments import NeoXArgs

    monkeypatch.setenv("WORLD_SIZE", "1")
    monkeypatch.setenv("RANK", "0")
    neox_args = NeoXArgs.consume_deepy_args(input_args)
    deepspeed_main_args = neox_args.get_deepspeed_main_args()
    return deepspeed_main_args


def save_random_model(input_args, model_dir, train_iters=0):
    # Save randomly initialised model
    train_args = {
        "do_train": False,
        "train_iters": train_iters,
        "save": model_dir,
        "extra_save_iters": [train_iters],
    }
    train.main(input_args=input_args, overwrite_values=train_args)
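
# Illustrative sketch combining simulate_deepy_env() and save_random_model() inside a test
# (not executed here); the config path and iteration count are placeholders:
#
#     def test_save_random_model_example(monkeypatch, tmpdir):
#         input_args = ["train.py", "tests/config/test_setup.yml"]
#         deepspeed_main_args = simulate_deepy_env(monkeypatch, input_args)
#         save_random_model(deepspeed_main_args, model_dir=str(tmpdir), train_iters=32)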


def bounded_product(sequence, n=None, seed=None):
    """
    Returns a shuffled, bounded cartesian product of the input sequence.
    Designed to cover as wide a range of permutations as possible with a limited number of iterations.
    Materializes the full product in memory, so it is not suitable for very large sequences.

    :param sequence: iterable of iterables to take the cartesian product of
    :param n: maximum length of the returned list (None returns the full product)
    :param seed: random seed for reproducibility
    :return: list of tuples
    """
    p = list(itertools.product(*sequence))
    if seed is not None:
        random.seed(seed)
    random.shuffle(p)
    return p if n is None else p[:n]
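

def _example_bounded_product():
    # Illustrative sketch (not used by the test suite): draw at most 3 shuffled pairs
    # from the 2 x 2 cartesian product of the two lists below.
    pairs = bounded_product([[1, 2], ["a", "b"]], n=3, seed=42)
    assert len(pairs) == 3
    assert all(pair in itertools.product([1, 2], ["a", "b"]) for pair in pairs)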


def model_setup_simple(deepspeed_main_args, overwrite_values, iteration=None):
    from megatron.neox_arguments import NeoXArgs
    from megatron import initialize_megatron
    from megatron.training import setup_model_and_optimizer

    neox_args = NeoXArgs.consume_neox_args(
        input_args=deepspeed_main_args, overwrite_values=overwrite_values
    )
    neox_args.configure_distributed_args()
    neox_args.build_tokenizer()
    initialize_megatron(neox_args=neox_args)
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        neox_args=neox_args, use_cache=False
    )
    return model, optimizer, lr_scheduler, neox_args


def parametrize(
    params_to_test: dict, max_tests: int = 50, seed: int = None, with_names=True
):
    """
    Generates a random sample of at most max_tests combinations of the values in
    `params_to_test`.

    Each key in `params_to_test` maps either a single parameter name to the list of values
    to test for it, or several parameter names joined by commas to a list of value tuples
    that are varied together, e.g.
        "hidden_size,num_heads": [[768,12], [1024,32], [2048, 64]]
    where the first item of each inner list is a `hidden_size` value and the second a
    `num_heads` value. Grouping parameters this way keeps the number of generated tests
    down for values we already expect not to interact, since the full cartesian product
    can grow very large.

    :param params_to_test: dict of neox params
    :param max_tests: maximum number of tests to run
    :param seed: random seed
    :param with_names: if True, also return a human-readable name for each generated config
    :return: a list of neox param dicts to pass to a parametrized unit test
    """
    keys, values = zip(*params_to_test.items())
    ret = []
    if with_names:
        experiments = []
    for p in bounded_product(values, n=max_tests, seed=seed):
        experiment = dict(zip(keys, p))
        to_pop = []
        to_add = {}
        for k, v in experiment.items():
            if "," in k:
                keys_split = [i.strip() for i in k.split(",")]
                values_separated = experiment[k]
                to_pop.append(k)
                assert len(values_separated) == len(keys_split)
                new_dict = dict(zip(keys_split, values_separated))
                to_add.update(new_dict)
        experiment.update(to_add)
        for k in to_pop:
            experiment.pop(k)
        base = deepcopy(BASE_CONFIG)
        base.update(experiment)
        ret.append(base)
        if with_names:
            experiments.append(experiment)
    if with_names:
        return ret, [dict_repr(d) for d in experiments]
    return ret
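
# Illustrative sketch of how parametrize() is typically consumed (parameter names and
# values below are placeholders, not a recommended configuration):
#
#     params_to_test = {
#         "norm,pos_emb": [["layernorm", "learned"], ["rmsnorm", "rotary"]],
#         "no_weight_tying": [True, False],
#     }
#
#     @pytest.mark.parametrize(
#         "param_dict",
#         parametrize(params_to_test, max_tests=4, seed=None, with_names=False),
#     )
#     def test_example(param_dict):
#         ...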


def dict_repr(d):
    return " ".join([f"{str(k)} : {str(v)}" for k, v in d.items()])


binary = [True, False]

with open("tests/config/test_setup.yml", "r") as f:
    BASE_CONFIG = load(f, Loader=Loader)
    print(f"Base Config:\n{BASE_CONFIG}")