LLM-FineTuning-Large-Language-Models

MoE_implementation_Mistral_official_Repo.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementation of Sparse Mixture-of-Experts layer in PyTorch from Mistral Official Repo\n",
    "\n",
    "📌 https://github.com/mistralai/mistral-src/blob/main/mistral/moe.py\n",
    "\n",
    "And it's super simple.\n",
    "\n",
    "--------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from simple_parsing.helpers import Serializable\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class MoeArgs(Serializable):\n",
    "    num_experts: int\n",
    "    num_experts_per_tok: int\n",
    "\n",
    "\n",
    "class MoeLayer(nn.Module):\n",
    "    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):\n",
    "        super().__init__()\n",
    "        assert len(experts) > 0\n",
    "        self.experts = nn.ModuleList(experts)\n",
    "        self.gate = gate\n",
    "        self.args = moe_args\n",
    "\n",
    "    def forward(self, inputs: torch.Tensor):\n",
    "        # Score every token against every expert: (num_tokens, num_experts).\n",
    "        gate_logits = self.gate(inputs)\n",
    "        # Keep only the top-k experts per token (sparse routing).\n",
    "        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)\n",
    "        # Normalize the surviving logits into per-token routing weights.\n",
    "        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)\n",
    "        results = torch.zeros_like(inputs)\n",
    "        for i, expert in enumerate(self.experts):\n",
    "            # Tokens routed to expert i, and the slot it occupies in their top-k.\n",
    "            batch_idx, nth_expert = torch.where(selected_experts == i)\n",
    "            # Run expert i on its tokens and accumulate the weighted outputs.\n",
    "            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(\n",
    "                inputs[batch_idx]\n",
    "            )\n",
    "        return results"
   ]
  },
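  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 A minimal usage sketch of `MoeLayer` (not from the Mistral repo; the `nn.Linear` experts, the gate, and the sizes below are illustrative assumptions):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy configuration -- the experts and gate are illustrative.\n",
    "dim, num_experts, num_experts_per_tok = 8, 4, 2\n",
    "\n",
    "experts = [nn.Linear(dim, dim) for _ in range(num_experts)]\n",
    "gate = nn.Linear(dim, num_experts, bias=False)\n",
    "moe = MoeLayer(experts, gate, MoeArgs(num_experts, num_experts_per_tok))\n",
    "\n",
    "x = torch.randn(5, dim)  # 5 tokens, flattened to (num_tokens, dim)\n",
    "out = moe(x)\n",
    "print(out.shape)  # torch.Size([5, 8])"
   ]
  },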
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 `torch.topk()` is used over the gate outputs to find the best experts per training example. It computes the top `num_experts_per_tok` logits for each token across the expert dimension. This operation returns two tensors: the top logits (`weights`) and their corresponding expert indices (`selected_experts`).\n",
    "\n",
    "📌 Then `torch.where()` is used to determine which training examples in the batch should be routed to which expert, using `selected_experts` to map each token to its allocated experts. The gating mechanism's sparsity is embodied here, as each token is only routed to a limited set of experts (as defined by `num_experts_per_tok`), rather than to all available experts.\n",
    "\n",
    "📌 `torch.where(selected_experts == i)` is used to find the indices in `selected_experts` where its elements equal `i`. This returns two tensors:\n",
    "- **batch_idx**: The indices of the batch dimension where the condition holds true.\n",
    "- **nth_expert**: The indices along the second dimension (the expert dimension in this context) for each true element in the condition.\n",
    "\n",
    "📌 The softmax applied to `weights` normalizes these logits, converting them into a probability distribution over the selected experts for each token. This step ensures that the contribution of each selected expert is weighted proportionally to its predicted relevance. A toy walkthrough of the routing indices follows below."
   ]
  },
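  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 A toy walkthrough of those routing indices (the numbers are illustrative, not from the repo): hand-written `gate_logits` for 3 tokens and 4 experts, with top-2 routing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative example: 3 tokens, 4 experts, top-2 routing.\n",
    "gate_logits = torch.tensor([[1.0, 3.0, 0.5, 2.0],\n",
    "                            [0.1, 0.2, 4.0, 1.0],\n",
    "                            [2.5, 0.3, 0.4, 2.6]])\n",
    "\n",
    "weights, selected_experts = torch.topk(gate_logits, k=2)\n",
    "print(selected_experts)  # tensor([[1, 3], [2, 3], [3, 0]])\n",
    "\n",
    "# Which tokens were routed to expert 3, and in which top-k slot?\n",
    "batch_idx, nth_expert = torch.where(selected_experts == 3)\n",
    "print(batch_idx)   # tensor([0, 1, 2]) -> every token selected expert 3\n",
    "print(nth_expert)  # tensor([1, 1, 0]) -> expert 3's slot in each token's top-2"
   ]
  },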
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----------------\n",
    "\n",
    "More explanations on the 2 key steps 🔽\n",
    "\n",
    "\n",
    "📌 `torch.topk()` returns the `k` largest elements from the given input tensor along a specified dimension. The function returns two tensors: the first contains the top `k` values, and the second contains the indices of these values in the tensor.\n",
    "\n",
    "Here,\n",
    "\n",
    "- `gate_logits` represents the output from the gating mechanism, which is essentially the scores or logits indicating how relevant each training example is to each expert.\n",
    "\n",
    "- `torch.topk(gate_logits, self.args.num_experts_per_tok)` finds the top `k` experts (where `k` is `self.args.num_experts_per_tok`) for each token or training example. The returned values are:\n",
    "    \n",
    "    - `weights`: The raw gate logits of each of the top `k` experts (the softmax is applied to these afterwards).\n",
    "    - `selected_experts`: The indices of these top `k` experts.\n",
    "\n",
    "📌 The `torch.topk` function, by default, operates on the last dimension of the input tensor unless otherwise specified by the `dim` argument. Since `gate_logits` is not explicitly reshaped or permuted in the code before the `topk` call, and the `dim` argument is not provided, the operation is performed across the expert dimension, which is the last dimension of the `gate_logits` tensor, as the quick check below confirms."
   ]
  },
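  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 A quick sanity check of that default `dim` behaviour (toy tensor, illustrative only):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# topk defaults to the last dimension (dim=-1).\n",
    "x = torch.tensor([[9.0, 1.0, 5.0],\n",
    "                  [2.0, 8.0, 3.0]])\n",
    "\n",
    "values, indices = torch.topk(x, k=2)  # no dim given\n",
    "print(values)   # tensor([[9., 5.], [8., 3.]])\n",
    "print(indices)  # tensor([[0, 2], [1, 2]])\n",
    "\n",
    "# Equivalent explicit call over the last dimension:\n",
    "v2, i2 = torch.topk(x, k=2, dim=-1)\n",
    "assert torch.equal(values, v2) and torch.equal(indices, i2)"
   ]
  },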
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "------------\n",
    "\n",
    "📌 `torch.where()` is used for conditional selection of elements from tensors. The function's signature is `torch.where(condition, x, y)`. It takes three arguments:\n",
    "\n",
    "- **condition**: A boolean tensor. The shape of the condition tensor dictates the shape of the output.\n",
    "- **x**: Tensor (or scalar) from which to take elements when the corresponding value in `condition` is `True`.\n",
    "- **y**: Tensor (or scalar) from which to take elements when the corresponding value in `condition` is `False`.\n",
    "\n",
    "📌 The output tensor is formed by selecting elements from `x` or `y` based on the `condition`. If `condition[i, j, ...] == True`, the output at that location is `x[i, j, ...]`; otherwise, it is `y[i, j, ...]`.\n",
    "\n",
    "📌 In this implementation, `torch.where()` is used differently: it is called with only the condition argument, which makes it behave like `torch.nonzero(condition, as_tuple=True)` and return the indices where the condition is true.\n",
    "\n",
    "📌 Here, `torch.where(selected_experts == i)` is used to find the indices in `selected_experts` where its elements equal `i`. This returns two tensors:\n",
    "- **batch_idx**: The indices of the batch dimension where the condition holds true.\n",
    "- **nth_expert**: The indices along the second dimension (the expert dimension in this context) for each true element in the condition.\n",
    "\n",
    "📌 These indices (`batch_idx` and `nth_expert`) are then used to route the inputs to the appropriate expert in the Mixture of Experts (MoE) layer. For each expert `i`, it finds which inputs (`inputs[batch_idx]`) should be processed by that expert. The results are scaled by the corresponding weights (`weights[batch_idx, nth_expert, None]`) and accumulated in `results`. Both calling conventions are demonstrated below."
   ]
  },
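  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "📌 A short demo of the two calling conventions of `torch.where()` (toy tensors, illustrative only):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cond = torch.tensor([[True, False],\n",
    "                     [False, True]])\n",
    "x = torch.tensor([[1, 2], [3, 4]])\n",
    "y = torch.tensor([[-1, -2], [-3, -4]])\n",
    "\n",
    "# Three-argument form: element-wise selection between x and y.\n",
    "print(torch.where(cond, x, y))  # tensor([[ 1, -2], [-3,  4]])\n",
    "\n",
    "# Condition-only form, as used in MoeLayer: indices where cond is True.\n",
    "rows, cols = torch.where(cond)\n",
    "print(rows, cols)  # tensor([0, 1]) tensor([0, 1])"
   ]
  }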
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}