## Implementation of Sparse Mixture-of-Experts layer in PyTorch from Mistral Official Repo

📌 https://github.com/mistralai/mistral-src/blob/main/mistral/moe.py

And it's super simple.

--------
```python
import dataclasses
from typing import List

import torch
import torch.nn.functional as F
from simple_parsing.helpers import Serializable
from torch import nn


@dataclasses.dataclass
class MoeArgs(Serializable):
    num_experts: int
    num_experts_per_tok: int


class MoeLayer(nn.Module):
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor):
        # Score every token against every expert
        gate_logits = self.gate(inputs)
        # Keep only the top num_experts_per_tok logits (and expert indices) per token
        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
        # Normalize the kept logits into per-token mixing weights
        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
        results = torch.zeros_like(inputs)
        for i, expert in enumerate(self.experts):
            # Indices of every (token, top-k slot) pair routed to expert i
            batch_idx, nth_expert = torch.where(selected_experts == i)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs[batch_idx]
            )
        return results
```
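📌 A minimal, hypothetical usage sketch (the `dim=16` width, the `nn.Linear` experts, and the linear gate below are toy stand-ins for illustration; in the actual Mistral repo the experts are full feed-forward blocks):

```python
import torch
from torch import nn

dim, num_experts, num_experts_per_tok = 16, 4, 2   # toy sizes, illustrative only

# 4 tiny linear "experts" and a linear gate scoring each token against each expert
experts = [nn.Linear(dim, dim) for _ in range(num_experts)]
gate = nn.Linear(dim, num_experts, bias=False)
moe = MoeLayer(experts, gate, MoeArgs(num_experts, num_experts_per_tok))

tokens = torch.randn(8, dim)   # 8 tokens, each a 16-dim vector
out = moe(tokens)
print(out.shape)               # torch.Size([8, 16])
```

Each token's output is a weighted sum over only its 2 selected experts, even though 4 experts exist in the layer.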
📌 `torch.topk()` is used over the gate outputs to find the best experts per training example. It computes the top `num_experts_per_tok` logits for each token across the expert dimension. This operation returns two tensors: the top logits (`weights`) and their corresponding expert indices (`selected_experts`).

📌 Then `torch.where()` is used to determine which training examples in the batch should be routed to which expert, using `selected_experts` to map each token to its allocated experts. The gating mechanism's sparsity is embodied here: each token is routed to only a limited set of experts (as defined by `num_experts_per_tok`) rather than to all available experts.

📌 `torch.where(selected_experts == i)` finds the indices in `selected_experts` where the elements equal `i`. It returns two tensors:
- **batch_idx**: the indices along the batch (token) dimension where the condition holds true.
- **nth_expert**: the indices along the second dimension (the top-k slot, i.e. which of the token's selected experts matched) for each true element.

📌 The softmax applied to `weights` normalizes these top-k logits, converting them into a probability distribution over the selected experts for each token. This step ensures that each selected expert's contribution is weighted proportionally to its predicted relevance.
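📌 A quick sanity check of that routing math on a toy tensor (all logit values below are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical gate output: 3 tokens scored against 4 experts
gate_logits = torch.tensor([[1.0, 3.0, 0.5, 2.0],
                            [0.2, 0.1, 4.0, 1.5],
                            [2.5, 2.4, 0.0, 0.3]])

weights, selected_experts = torch.topk(gate_logits, k=2)
print(selected_experts)        # tensor([[1, 3], [2, 3], [0, 1]])

weights = F.softmax(weights, dim=1)
print(weights.sum(dim=1))      # tensor([1., 1., 1.]) -- one distribution per token

# Which (token, slot) pairs were routed to expert 3?
batch_idx, nth_expert = torch.where(selected_experts == 3)
print(batch_idx, nth_expert)   # tensor([0, 1]) tensor([1, 1])
```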
----------------

More explanations on the two key steps 🔽


📌 `torch.topk()` returns the `k` largest elements of the given input tensor along a specified dimension. The function returns two tensors: the first contains the top `k` values, and the second contains the indices of those values in the tensor.

Here,

- `gate_logits` is the output of the gating mechanism: the scores, or logits, indicating how relevant each training example is to each expert.

- `torch.topk(gate_logits, self.args.num_experts_per_tok)` finds the top `k` experts (where `k` is `self.args.num_experts_per_tok`) for each token or training example. The returned values are:

  - `weights`: the raw gate logits of the top `k` experts (they only become probabilities after the subsequent softmax).
  - `selected_experts`: the indices of these top `k` experts.

📌 The `torch.topk` function, by default, operates on the last dimension of the input tensor unless otherwise specified by the `dim` argument. Since `gate_logits` is neither reshaped nor permuted before the `topk` call, and no `dim` argument is provided, the operation runs across the expert dimension, which is the last dimension of the `gate_logits` tensor.
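📌 A one-line check of that default behavior (toy tensor, illustrative only):

```python
import torch

x = torch.arange(12.0).reshape(3, 4)   # toy (tokens, experts) logits
v1, i1 = torch.topk(x, k=2)            # no dim given ...
v2, i2 = torch.topk(x, k=2, dim=-1)    # ... behaves like dim=-1 (the last dim)
assert torch.equal(v1, v2) and torch.equal(i1, i2)
print(i1)                              # tensor([[3, 2], [3, 2], [3, 2]])
```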
------------

📌 `torch.where()` is usually used for conditional selection of elements from tensors. The function's signature is `torch.where(condition, x, y)`. It takes three arguments:

- **condition**: a boolean tensor. The shape of the condition tensor dictates the shape of the output.
- **x**: tensor (or scalar) from which to take elements when the corresponding value in `condition` is `True`.
- **y**: tensor (or scalar) from which to take elements when the corresponding value in `condition` is `False`.

📌 The output tensor is formed by selecting elements from `x` or `y` based on the `condition`. If `condition[i, j, ...] == True`, the output at that location is `x[i, j, ...]`; otherwise, it is `y[i, j, ...]`.

📌 In this implementation, `torch.where()` is used differently: it is called with only the condition, in which case it returns the indices where the condition is true (equivalent to `torch.nonzero(condition, as_tuple=True)`).

📌 So `torch.where(selected_experts == i)` returns **batch_idx** (the token indices where expert `i` was selected) and **nth_expert** (the top-k slot in which it was selected).

📌 These indices (`batch_idx` and `nth_expert`) are then used to route the inputs to the appropriate expert in the Mixture-of-Experts (MoE) layer. For each expert `i`, it finds which inputs (`inputs[batch_idx]`) should be processed by that expert. The results are scaled by the corresponding weights (`weights[batch_idx, nth_expert, None]`) and accumulated into `results`.
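📌 To make the contrast concrete, a small sketch of both calling conventions (toy tensors):

```python
import torch

cond = torch.tensor([[True, False],
                     [False, True]])
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[9, 9], [9, 9]])

# Three-argument form: element-wise selection between x and y
print(torch.where(cond, x, y))   # tensor([[1, 9], [9, 4]])

# Condition-only form, as used in the MoE layer: index tensors of True positions,
# equivalent to torch.nonzero(cond, as_tuple=True)
rows, cols = torch.where(cond)
print(rows, cols)                # tensor([0, 1]) tensor([0, 1])
```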