// Package microgopt is a Go port of a minimal Python GPT training script. It
// contains a scalar autograd engine, a small character-level transformer, an
// Adam optimizer and a sampler, using only the standard library.
package microgopt

import (
	"fmt"
	"maps"
	"math"
	"math/rand/v2"
	"slices"
	"sort"
	"strings"
)

// Model hyperparameters. The learned parameters themselves (the knowledge of
// the model) live in stateMap.
const (
	nLayer    = 1             // depth of the transformer neural network (number of layers)
	nEmbd     = 16            // width of the network (embedding dimension)
	blockSize = 16            // maximum context length of the attention window (note: the longest name is 15 characters)
	nHead     = 4             // number of attention heads
	headDim   = nEmbd / nHead // derived dimension of each head
)

// stateMap holds every weight matrix of the model, keyed by name ("wte",
// "wpe", "lm_head", "layer<i>.attn_wq", ...). Each matrix is a slice of rows
// of scalar autograd nodes. It is (re)populated by initParams.
var stateMap = map[string][][]*value{}

// btof converts a bool to 1.0 or 0.0. Python lets bools act as numbers; Go
// needs the conversion spelled out.
func btof(b bool) float64 {
	if b {
		return 1
	}
	return 0
}

// valcmp orders two nodes by their forward value, for use with slices.MaxFunc
// and similar helpers.
func valcmp(a, b *value) int {
	switch {
	case a.data < b.data:
		return -1
	case a.data > b.data:
		return 1
	default:
		return 0
	}
}

// Run trains the character-level GPT on docs (one document per entry) for a
// fixed number of steps, then prints 20 samples generated by the trained model.
//
// Progress and samples are written to stdout. Each document is trimmed of
// surrounding whitespace and blank documents are skipped; the caller's slice
// is not modified. If no non-blank documents remain, Run prints the count and
// returns without training. The model weights are stored in stateMap and are
// re-initialised on every call.
func Run(docs []string) {
	// Remove leading and trailing whitespace in documents and drop blank ones.
	// Work on a copy so the caller's slice is left untouched.
	cleaned := make([]string, 0, len(docs))
	for _, doc := range docs {
		if doc = strings.TrimSpace(doc); doc != "" {
			cleaned = append(cleaned, doc)
		}
	}
	docs = cleaned
	fmt.Printf("num docs: %d\n", len(docs))
	if len(docs) == 0 {
		// Nothing to learn from; step%len(docs) below would also divide by zero.
		return
	}

	// Construct the vocabulary from the documents: an ordered list of all
	// characters in the dataset, plus a "Beginning Of Sequence" (BOS) token.
	set := map[rune]struct{}{}
	for _, doc := range docs {
		for _, c := range doc {
			set[c] = struct{}{}
		}
	}
	uchars := slices.Sorted(maps.Keys(set))
	charIndex := make(map[rune]int, len(uchars)) // O(1) rune -> token id lookup
	for i, c := range uchars {
		charIndex[c] = i
	}
	BOS := len(uchars)
	vocabSize := len(uchars) + 1
	fmt.Printf("vocab size: %d\n", vocabSize)

	params := initParams(vocabSize)
	fmt.Printf("num params: %d\n", len(params))

	// "Let there be Adam, the blessed optimizer and its buffers"
	learningRate, beta1, beta2, epsAdam := 0.01, 0.85, 0.99, 1e-8
	m := make([]float64, len(params)) // first moment buffer
	v := make([]float64, len(params)) // second moment buffer

	// Repeat in sequence
	numSteps := 1000 // number of training steps
	for step := range numSteps {
		// Take single document, tokenize it, surround it with BOS special token on both sides
		doc := docs[step%len(docs)]
		tokens := make([]int, 0, len(doc)+2)
		tokens = append(tokens, BOS)
		for _, ch := range doc {
			tokens = append(tokens, charIndex[ch])
		}
		tokens = append(tokens, BOS)
		n := min(blockSize, len(tokens)-1)

		// Forward the token sequence through the model, building up the
		// computation graph all the way to the loss.
		keys, values := mkDeepSlice(), mkDeepSlice()
		lossSum := &value{}
		for posID := range n {
			tokenID, targetID := tokens[posID], tokens[posID+1]
			probs := softMax(gpt(tokenID, posID, keys, values))
			lossSum = lossSum.Add(probs[targetID].Log().Neg())
		}
		// final average loss over the document sequence. May yours be low.
		loss := (&value{data: 1 / float64(n)}).Mul(lossSum)

		// Backward the loss, calculating the gradients with respect to all model parameters
		loss.Backward()

		// Adam optimizer update with a linearly decaying learning rate and
		// bias-corrected moment estimates.
		lrt := learningRate * (1 - float64(step)/float64(numSteps))
		bias1 := 1 - math.Pow(beta1, float64(step+1))
		bias2 := 1 - math.Pow(beta2, float64(step+1))
		for i, p := range params {
			m[i] = beta1*m[i] + (1-beta1)*p.grad
			v[i] = beta2*v[i] + (1-beta2)*p.grad*p.grad
			mHat := m[i] / bias1
			vHat := v[i] / bias2
			p.data -= lrt * mHat / (math.Sqrt(vHat) + epsAdam)
			p.grad = 0 // Backward accumulates, so reset before the next step
		}
		fmt.Printf("step %4d / %4d | loss %.4f\r", step+1, numSteps, loss.data)
	}

	// Inference: may the model babble back to us
	temperature := 0.5 // in (0, 1], control the "creativity" of generated text, low to high
	fmt.Println("\n--- inference (new, hallucinated names) ---")
	for sampleIdx := range 20 {
		fmt.Printf("sample %2d: %s\n", sampleIdx+1, sampleName(uchars, BOS, temperature))
	}
}

// initParams fills stateMap with freshly initialised weight matrices for a
// model with the given vocabulary size. It returns every parameter flattened
// into one slice, which the optimizer indexes in parallel with its moment
// buffers. Matrices are flattened in sorted key order so the order is stable.
func initParams(vocabSize int) []*value {
	stateMap["wte"] = genMatrix(vocabSize, nEmbd)
	stateMap["wpe"] = genMatrix(blockSize, nEmbd)
	stateMap["lm_head"] = genMatrix(vocabSize, nEmbd)
	for i := range nLayer {
		prefix := fmt.Sprintf("layer%d.", i)
		stateMap[prefix+"attn_wq"] = genMatrix(nEmbd, nEmbd)
		stateMap[prefix+"attn_wk"] = genMatrix(nEmbd, nEmbd)
		stateMap[prefix+"attn_wv"] = genMatrix(nEmbd, nEmbd)
		stateMap[prefix+"attn_wo"] = genMatrix(nEmbd, nEmbd)
		stateMap[prefix+"mlp_fc1"] = genMatrix(4*nEmbd, nEmbd)
		stateMap[prefix+"mlp_fc2"] = genMatrix(nEmbd, 4*nEmbd)
	}

	var params []*value
	for _, name := range slices.Sorted(maps.Keys(stateMap)) {
		for _, row := range stateMap[name] {
			params = append(params, row...)
		}
	}
	return params
}

// sampleName generates one string by autoregressively sampling from the model.
// It starts from the bos token and stops when bos is sampled again or after
// blockSize tokens. The logits are divided by temperature before the softmax,
// so lower values make the output more conservative.
func sampleName(uchars []rune, bos int, temperature float64) string {
	keys, values := mkDeepSlice(), mkDeepSlice()
	tokenID := bos
	var out []rune
	for posID := range blockSize {
		logits := gpt(tokenID, posID, keys, values)
		scaled := make([]*value, len(logits))
		for i, l := range logits {
			scaled[i] = l.Div(&value{data: temperature})
		}
		// Plain assignment (not :=) so the next position is fed the token just
		// sampled instead of bos again.
		tokenID = choose(softMax(scaled))
		if tokenID == bos {
			break
		}
		out = append(out, uchars[tokenID])
	}
	return string(out)
}

// genMatrix returns an out×in matrix of fresh leaf nodes whose values are drawn
// from a normal distribution with mean 0 and standard deviation 0.08.
func genMatrix(out, in int) [][]*value {
	m := make([][]*value, out)
	for o := range m {
		m[o] = make([]*value, in)
		for i := range m[o] {
			m[o][i] = &value{data: rand.NormFloat64() * 0.08}
		}
	}
	return m
}

// linear returns the matrix-vector product w·x. There is one output per row of
// w, each the dot product of that row with x. Rows of w must not be longer
// than x.
func linear(x []*value, w [][]*value) []*value {
	out := make([]*value, 0, len(w))
	for _, row := range w {
		s := &value{}
		for i, wi := range row {
			s = s.Add(wi.Mul(x[i]))
		}
		out = append(out, s)
	}
	return out
}

// softMax returns a new slice of probabilities, exp(l_i) / Σ exp(l_j), built
// as autograd nodes. logits must be non-empty (slices.MaxFunc panics
// otherwise).
func softMax(logits []*value) []*value {
	// Subtract the max logit for numerical stability. It is taken as a constant
	// leaf so no gradient is routed through it. Softmax is shift-invariant, so
	// this changes neither the output nor the true gradient.
	maxVal := &value{data: slices.MaxFunc(logits, valcmp).data}
	exps := make([]*value, len(logits))
	total := &value{}
	for i, l := range logits {
		exps[i] = l.Sub(maxVal).Exp()
		total = total.Add(exps[i])
	}
	for i := range exps {
		exps[i] = exps[i].Div(total)
	}
	return exps
}

// rmsNorm returns a NEW slice holding x scaled by 1/sqrt(mean(x²) + 1e-5).
// x itself is left untouched. Callers such as gpt keep a reference to x as
// a residual, so normalising in place would corrupt it.
func rmsNorm(x []*value) []*value {
	ms := &value{}
	for _, xi := range x {
		ms = ms.Add(xi.Mul(xi))
	}
	ms = ms.Div(&value{data: float64(len(x))})
	scale := ms.Add(&value{data: 1e-5}).Pow(&value{data: -0.5})

	out := make([]*value, len(x))
	for i, xi := range x {
		out[i] = xi.Mul(scale)
	}
	return out
}

// gpt runs one forward step of the transformer for token tokenID at position
// posID and returns the vocabulary logits for the next token.
//
// keys and values are per-layer caches, as created by mkDeepSlice. This
// call appends the current position's key and value vectors to them, so
// attention covers this position and every position processed earlier with
// the same caches. Callers are presumably expected to feed positions 0, 1,
// 2, ... in order.
func gpt(tokenID int, posID int, keys [][][]*value, values [][][]*value) []*value {
	tokEmb := stateMap["wte"][tokenID] // token embedding
	posEmb := stateMap["wpe"][posID]   // position embedding
	x := make([]*value, len(tokEmb))   // joint token and position embedding
	for i := range tokEmb {
		x[i] = tokEmb[i].Add(posEmb[i])
	}
	x = rmsNorm(x) // note: not redundant due to backward pass via the residual connection

	attnScale := math.Sqrt(headDim)
	for li := range nLayer {
		prefix := fmt.Sprintf("layer%d.", li)

		// 1) Multi-head Attention block
		xResidual := x // safe to alias: rmsNorm does not modify its input
		x = rmsNorm(x)
		q := linear(x, stateMap[prefix+"attn_wq"])
		k := linear(x, stateMap[prefix+"attn_wk"])
		v := linear(x, stateMap[prefix+"attn_wv"])
		keys[li] = append(keys[li], k)
		values[li] = append(values[li], v)

		// Distribute the work over the attention heads; head h owns the
		// slice [h*headDim, (h+1)*headDim) of q, k and v.
		xAttn := make([]*value, 0, nEmbd)
		for h := range nHead {
			hs := h * headDim
			qh := q[hs : hs+headDim]

			attnLogits := make([]*value, len(keys[li]))
			for t, kt := range keys[li] {
				s := &value{}
				for j, qj := range qh {
					s = s.Add(qj.Mul(kt[hs+j]))
				}
				attnLogits[t] = s.Div(&value{data: attnScale})
			}
			attnWeights := softMax(attnLogits)

			for j := range headDim {
				s := &value{}
				for t, vt := range values[li] {
					s = s.Add(attnWeights[t].Mul(vt[hs+j]))
				}
				xAttn = append(xAttn, s)
			}
		}
		x = linear(xAttn, stateMap[prefix+"attn_wo"])
		for i := range x {
			x[i] = x[i].Add(xResidual[i])
		}

		// 2) MLP block
		xResidual = x
		x = rmsNorm(x) // fresh slice, so xResidual keeps the pre-norm activations
		x = linear(x, stateMap[prefix+"mlp_fc1"])
		for i := range x {
			x[i] = x[i].Relu()
		}
		x = linear(x, stateMap[prefix+"mlp_fc2"])
		for i := range x {
			x[i] = x[i].Add(xResidual[i])
		}
	}
	return linear(x, stateMap["lm_head"])
}

// value is a node in a scalar autograd graph.
//
// data is the forward value. grad accumulates d(root)/d(node) during Backward
// and starts at 0. A node has at most two children. lLocalGrad and
// rLocalGrad hold the partial derivative of this node with respect to lChild
// and rChild; only their data field is read. Unary operations use rChild
// only.
type value struct {
	data       float64
	grad       float64 // implicitly 0 to start
	lChild     *value
	rChild     *value
	lLocalGrad *value
	rLocalGrad *value
}

// Add returns v + other.
func (v *value) Add(other *value) *value {
	return &value{
		data:       v.data + other.data,
		lChild:     v,
		rChild:     other,
		lLocalGrad: &value{data: 1},
		rLocalGrad: &value{data: 1},
	}
}

// Sub returns v - other.
func (v *value) Sub(other *value) *value {
	return v.Add(other.Neg())
}

// Div returns v / other, expressed as v * other^-1.
func (v *value) Div(other *value) *value {
	return v.Mul(other.Pow(&value{data: -1}))
}

// Mul returns v * other. ∂/∂v = other and ∂/∂other = v.
func (v *value) Mul(other *value) *value {
	return &value{
		data:       v.data * other.data,
		lChild:     v,
		rChild:     other,
		lLocalGrad: &value{data: other.data},
		rLocalGrad: &value{data: v.data},
	}
}

// Pow returns v^other. The exponent is treated as a constant: only v receives
// a gradient, other·v^(other-1).
func (v *value) Pow(other *value) *value {
	return &value{
		data:       math.Pow(v.data, other.data),
		rChild:     v,
		rLocalGrad: &value{data: other.data * math.Pow(v.data, other.data-1)},
	}
}

// Neg returns -v.
func (v *value) Neg() *value {
	return v.Mul(&value{data: -1})
}

// Log returns the natural logarithm of v, with local gradient 1/v.
func (v *value) Log() *value {
	return &value{
		data:       math.Log(v.data),
		rChild:     v,
		rLocalGrad: &value{data: 1 / v.data},
	}
}

// Exp returns e^v, which is also its own local gradient.
func (v *value) Exp() *value {
	e := math.Exp(v.data)
	return &value{
		data:       e,
		rChild:     v,
		rLocalGrad: &value{data: e},
	}
}

// Relu returns max(v, 0). The local gradient is 1 for positive inputs and 0
// otherwise.
func (v *value) Relu() *value {
	return &value{
		data:       max(v.data, 0),
		rChild:     v,
		rLocalGrad: &value{data: btof(v.data > 0)},
	}
}

// Backward sets v.grad to 1 and back-propagates through every node reachable
// from v in reverse topological order. Gradients are ADDED to each node's
// grad, so leaves shared across calls (the parameters) must be zeroed by the
// caller between steps.
func (v *value) Backward() {
	var topo []*value
	visited := map[*value]struct{}{}
	var build func(n *value)
	build = func(n *value) {
		if _, seen := visited[n]; seen {
			return
		}
		visited[n] = struct{}{}
		if n.rChild != nil {
			build(n.rChild)
		}
		if n.lChild != nil {
			build(n.lChild)
		}
		topo = append(topo, n) // children are appended before their parent
	}
	build(v)

	v.grad = 1
	for _, n := range slices.Backward(topo) {
		if n.rChild != nil {
			n.rChild.grad += n.rLocalGrad.data * n.grad
		}
		if n.lChild != nil {
			n.lChild.grad += n.lLocalGrad.data * n.grad
		}
	}
}

// mkDeepSlice returns an empty KV cache for gpt. It holds one initially empty
// list of cached vectors for each of the nLayer layers.
func mkDeepSlice() [][][]*value {
	cache := make([][][]*value, nLayer)
	for i := range cache {
		cache[i] = make([][]*value, 0, blockSize)
	}
	return cache
}

// choose draws an index from p, treating each p[i].data as an unnormalised,
// non-negative weight. This is our own weighted random chooser, based on
// https://cybernetist.com/2019/01/24/random-weighted-draws-in-go/ but without
// the dependency on gonum.
//
// If the draw cannot be resolved, choose falls back to the last index. That
// happens when all weights are zero or the weights contain NaN. For an empty
// p it returns -1.
func choose(p []*value) int {
	// Build the discrete CDF of the weights.
	cdf := make([]float64, len(p))
	total := 0.0
	for i, w := range p {
		total += w.data
		cdf[i] = total
	}

	// Draw uniformly from [0, total) (cheaper than normalising the CDF), then
	// binary-search for the smallest index whose CDF exceeds the draw.
	val := rand.Float64() * total
	i := sort.Search(len(cdf), func(i int) bool { return cdf[i] > val })
	if i == len(cdf) {
		i = len(cdf) - 1
	}
	return i
}