// Package microgopt is a from-scratch micro GPT: a tiny scalar autograd
// engine plus a character-level transformer, trained and sampled entirely
// within this package.
package microgopt
|
|
|
|
import (
|
|
"fmt"
|
|
"maps"
|
|
"math"
|
|
"math/rand/v2"
|
|
"slices"
|
|
"sort"
|
|
"strings"
|
|
)
|
|
|
|
// Initialize the parameters, to store the knowledge of the model
const (
	nLayer    = 1             // depth of the transformer neural network (number of layers)
	nEmbd     = 16            // width of the network (embedding dimension)
	blockSize = 16            // maximum context length of the attention window (note: the longest name is 15 characters)
	nHead     = 4             // number of attention heads
	headDim   = nEmbd / nHead // derived dimension of each head
)
|
|
|
|
// stateMap holds every trainable weight matrix, keyed by a layer/parameter
// name such as "wte" or "layer0.attn_wq". Populated by Run before training.
var stateMap = map[string][][]*value{}
|
|
|
|
// btof converts a bool to a float64 (true -> 1.0, false -> 0.0).
// This type pun just worked in Python but Go needs to be more explicit.
func btof(b bool) float64 {
	var f float64
	if b {
		f = 1.0
	}
	return f
}
|
|
|
|
func valcmp(a, b *value) int {
|
|
if a.data < b.data {
|
|
return -1
|
|
} else if a.data > b.data {
|
|
return 1
|
|
} else {
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func Run(docs []string) {
|
|
// remove leading and trailing whitespace in documents
|
|
for i := range docs {
|
|
docs[i] = strings.TrimSpace(docs[i])
|
|
}
|
|
fmt.Printf("num docs: %d\n", len(docs))
|
|
|
|
// construct the vocabulary from the documents: an ordered list of all characters in the dataset,
|
|
// plus a "Beginning Of Sequence" (BOS) token
|
|
set := map[rune]struct{}{}
|
|
for _, doc := range docs {
|
|
for _, c := range doc {
|
|
set[c] = struct{}{}
|
|
}
|
|
}
|
|
uchars := slices.Sorted(maps.Keys(set))
|
|
BOS := len(uchars)
|
|
vocabSize := len(uchars) + 1
|
|
fmt.Printf("vocab size: %d\n", vocabSize)
|
|
|
|
// in the python code, at this point, the Value class was created
|
|
// and the global parameters were set up
|
|
|
|
stateMap["wte"] = genMatrix(vocabSize, nEmbd)
|
|
stateMap["wpe"] = genMatrix(blockSize, nEmbd)
|
|
stateMap["lm_head"] = genMatrix(vocabSize, nEmbd)
|
|
for i := range nLayer {
|
|
stateMap[fmt.Sprintf("layer%d.attn_wq", i)] = genMatrix(nEmbd, nEmbd)
|
|
stateMap[fmt.Sprintf("layer%d.attn_wk", i)] = genMatrix(nEmbd, nEmbd)
|
|
stateMap[fmt.Sprintf("layer%d.attn_wv", i)] = genMatrix(nEmbd, nEmbd)
|
|
stateMap[fmt.Sprintf("layer%d.attn_wo", i)] = genMatrix(nEmbd, nEmbd)
|
|
stateMap[fmt.Sprintf("layer%d.mlp_fc1", i)] = genMatrix(4*nEmbd, nEmbd)
|
|
stateMap[fmt.Sprintf("layer%d.mlp_fc2", i)] = genMatrix(nEmbd, 4*nEmbd)
|
|
}
|
|
// flatten params into a single []value
|
|
params := []*value{}
|
|
for _, mat := range stateMap {
|
|
for _, row := range mat {
|
|
for _, p := range row {
|
|
params = append(params, p)
|
|
}
|
|
}
|
|
}
|
|
fmt.Printf("num params: %d\n", len(params))
|
|
|
|
// at this point in the python, linear(), softmax(), rmsnorm(), and gpt() are all defined
|
|
|
|
// "Let there be Adam, the blessed optimizer and its buffers"
|
|
learningRate, beta1, beta2, epsAdam := 0.01, 0.85, 0.99, 1e-8
|
|
m := make([]float64, len(params)) // first moment buffer
|
|
v := make([]float64, len(params)) // second moment buffer
|
|
|
|
// Repeat in sequence
|
|
numSteps := 1000 // number of training steps
|
|
for step := range numSteps {
|
|
// Take single document, tokenize it, surround it with BOS special token on both sides
|
|
doc := docs[step%len(docs)]
|
|
tokens := []int{BOS}
|
|
for _, ch := range doc {
|
|
tokens = append(tokens, slices.Index(uchars, ch))
|
|
}
|
|
tokens = append(tokens, BOS)
|
|
n := min(blockSize, len(tokens)-1)
|
|
|
|
// Forward the token sequence through the model, building up the computation graph all the way to the loss
|
|
keys, values := mkDeepSlice(nLayer), mkDeepSlice(nLayer)
|
|
losses := []*value{}
|
|
for posId := range n {
|
|
tokenId, targetId := tokens[posId], tokens[posId+1]
|
|
logits := gpt(tokenId, posId, keys, values)
|
|
probs := softMax(logits)
|
|
lossT := probs[targetId].Log().Neg()
|
|
losses = append(losses, lossT)
|
|
}
|
|
lossSum := &value{}
|
|
for _, l := range losses {
|
|
lossSum = lossSum.Add(l)
|
|
}
|
|
loss := (&value{data: 1 / float64(n)}).Mul(lossSum) // final average loss over the document sequence. May yours be low.
|
|
// Backward the loss, calculating the gradients with respect to all model parameters
|
|
loss.Backward()
|
|
|
|
// Adam optimizer update: update the model parameters based on the corresponding gradients
|
|
lrt := learningRate * (float64(1) - float64(step)/float64(numSteps))
|
|
for i, p := range params {
|
|
m[i] = beta1*m[i] + (1-beta1)*p.grad
|
|
v[i] = beta2*v[i] + (1-beta2)*math.Pow(p.grad, 2.0)
|
|
m_hat := m[i] / (1 - math.Pow(beta1, float64(step+1)))
|
|
v_hat := v[i] / (1 - math.Pow(beta2, float64(step+1)))
|
|
p.data = p.data - (lrt*m_hat)/(math.Pow(v_hat, 0.5)+epsAdam)
|
|
p.grad = 0.0
|
|
}
|
|
fmt.Printf("step %4d / %4d | loss %.4f\r", step+1, numSteps, loss.data)
|
|
}
|
|
|
|
// Inference: may the model babble back to us
|
|
temperature := 0.5 // in (0, 1], control the "creativity" of generated text, low to high
|
|
fmt.Println("\n--- inference (new, hallucinated names) ---")
|
|
for sampleIdx := range 20 {
|
|
keys, values := mkDeepSlice(nLayer), mkDeepSlice(nLayer)
|
|
tokenId := BOS
|
|
sample := []rune{}
|
|
for posId := range blockSize {
|
|
logits := gpt(tokenId, posId, keys, values)
|
|
probs := make([]*value, len(logits))
|
|
for i, l := range logits {
|
|
probs[i] = l.Div(&value{data: temperature})
|
|
}
|
|
probs = softMax(probs)
|
|
tokenId := choose(probs)
|
|
if tokenId == BOS {
|
|
break
|
|
}
|
|
sample = append(sample, uchars[tokenId])
|
|
}
|
|
fmt.Printf("sample %2d: %s\n", sampleIdx+1, string(sample))
|
|
}
|
|
}
|
|
|
|
func genMatrix(out, in int) [][]*value {
|
|
m := make([][]*value, out)
|
|
for o := range out {
|
|
m[o] = make([]*value, in)
|
|
for i := range in {
|
|
m[o][i] = &value{data: rand.NormFloat64() * 0.08}
|
|
}
|
|
}
|
|
return m
|
|
}
|
|
|
|
func linear(x []*value, w [][]*value) []*value {
|
|
r := []*value{}
|
|
for _, wo := range w {
|
|
s := &value{data: 0.0}
|
|
for i := range wo {
|
|
s = s.Add(wo[i].Mul(x[i]))
|
|
}
|
|
r = append(r, s)
|
|
}
|
|
return r
|
|
}
|
|
|
|
func softMax(logits []*value) []*value {
|
|
maxVal := slices.MaxFunc(logits, valcmp)
|
|
exps := []*value{}
|
|
for _, val := range logits {
|
|
exps = append(exps, val.Sub(maxVal).Exp())
|
|
}
|
|
total := &value{}
|
|
for _, e := range exps {
|
|
total = total.Add(e)
|
|
}
|
|
for i := range exps {
|
|
exps[i] = exps[i].Div(total)
|
|
}
|
|
return exps
|
|
}
|
|
|
|
func rmsNorm(x []*value) []*value {
|
|
ms := &value{}
|
|
for _, xi := range x {
|
|
ms = ms.Add(xi.Mul(xi))
|
|
}
|
|
ms = ms.Div(&value{data: float64(len(x))})
|
|
scale := ms.Add(&value{data: 1e-5}).Pow(&value{data: -0.5})
|
|
for i := range x {
|
|
x[i] = x[i].Mul(scale)
|
|
}
|
|
return x
|
|
}
|
|
|
|
func gpt(tokenId int, posId int, keys [][][]*value, values [][][]*value) []*value {
|
|
tokEmb := stateMap["wte"][tokenId] // token embedding
|
|
posEmb := stateMap["wpe"][posId] // position embedding
|
|
x := []*value{}
|
|
// joint token and position embedding
|
|
for i := range tokEmb {
|
|
x = append(x, tokEmb[i].Add(posEmb[i]))
|
|
}
|
|
x = rmsNorm(x) // note: not redundant due to backward pass via the residual connection
|
|
|
|
for li := range nLayer {
|
|
// 1) Multi-head Attention block
|
|
xResidual := slices.Clone(x)
|
|
x = rmsNorm(x)
|
|
q := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wq", li)])
|
|
k := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wk", li)])
|
|
v := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wv", li)])
|
|
keys[li] = append(keys[li], k)
|
|
values[li] = append(values[li], v)
|
|
xAttn := []*value{}
|
|
// basically, distribute the work over the "attention heads"
|
|
for h := range nHead {
|
|
hs := h * headDim
|
|
q_h := q[hs : hs+headDim]
|
|
k_h := [][]*value{}
|
|
for _, ki := range keys[li] {
|
|
k_h = append(k_h, ki[hs:hs+headDim])
|
|
}
|
|
v_h := [][]*value{}
|
|
for _, vi := range values[li] {
|
|
v_h = append(v_h, vi[hs:hs+headDim])
|
|
}
|
|
attnLogits := []*value{}
|
|
for t := range len(k_h) {
|
|
s := &value{data: 0.0}
|
|
for j := range headDim {
|
|
s = s.Add(q_h[j].Mul(k_h[t][j]))
|
|
}
|
|
attnLogits = append(attnLogits, s.Div(&value{data: math.Pow(float64(headDim), 0.5)}))
|
|
}
|
|
attnWeights := softMax(attnLogits)
|
|
headOut := []*value{}
|
|
for j := range headDim {
|
|
s := &value{data: 0.0}
|
|
for t := range len(v_h) {
|
|
s = s.Add(attnWeights[t].Mul(v_h[t][j]))
|
|
}
|
|
headOut = append(headOut, s)
|
|
}
|
|
xAttn = append(xAttn, headOut...)
|
|
}
|
|
x = linear(xAttn, stateMap[fmt.Sprintf("layer%d.attn_wo", li)])
|
|
for i := range x {
|
|
x[i] = x[i].Add(xResidual[i])
|
|
}
|
|
// 2) MLP block
|
|
xResidual = x
|
|
x = rmsNorm(x)
|
|
x = linear(x, stateMap[fmt.Sprintf("layer%d.mlp_fc1", li)])
|
|
for i := range x {
|
|
x[i] = x[i].Relu()
|
|
}
|
|
x = linear(x, stateMap[fmt.Sprintf("layer%d.mlp_fc2", li)])
|
|
for i := range x {
|
|
x[i] = x[i].Add(xResidual[i])
|
|
}
|
|
}
|
|
logits := linear(x, stateMap["lm_head"])
|
|
return logits
|
|
}
|
|
|
|
// value is one scalar node in the autograd computation graph (micrograd-style).
type value struct {
	data       float64  // the scalar payload of this node
	grad       float64  // d(output)/d(this node), accumulated by Backward; implicitly 0 to start
	children   []*value // the operand nodes this node was computed from
	localGrads []*value // d(this)/d(children[i]), index-paired with children; Backward reads only .data
}
|
|
|
|
// this lets us build a set-like map with our Values.
// If the slices were removed from the struct, that would make this method irrelevant.
// toKey renders the node via "%+v" to serve as a set key in Backward.
// NOTE(review): two distinct leaf nodes with identical data/grad and empty
// slices format to the same string and therefore collide; this appears benign
// here because leaves have no children to traverse, but a pointer-keyed
// map[*value]struct{} would be strictly safer — confirm before reusing.
func (v *value) toKey() string {
	k := fmt.Sprintf("%+v", v)
	return k
}
|
|
|
|
func (v *value) Add(other *value) *value {
|
|
return &value{
|
|
data: v.data + other.data,
|
|
children: []*value{v, other},
|
|
localGrads: []*value{{data: 1.0}, {data: 1.0}},
|
|
}
|
|
}
|
|
|
|
func (v *value) Sub(other *value) *value {
|
|
return v.Add(other.Neg())
|
|
}
|
|
|
|
func (v *value) Div(other *value) *value {
|
|
return v.Mul(other.Pow(&value{data: -1}))
|
|
}
|
|
|
|
func (v *value) Mul(other *value) *value {
|
|
// note the swap here: children are stored as v, other but grads are other, v
|
|
return &value{
|
|
data: v.data * other.data,
|
|
children: []*value{v, other},
|
|
localGrads: []*value{
|
|
{data: other.data},
|
|
{data: v.data},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (v *value) Pow(other *value) *value {
|
|
return &value{
|
|
data: math.Pow(v.data, other.data),
|
|
children: []*value{v},
|
|
localGrads: []*value{
|
|
other.Mul(&value{data: math.Pow(v.data, other.Sub(&value{data: 1}).data)}),
|
|
}}
|
|
}
|
|
|
|
func (v *value) Neg() *value {
|
|
return v.Mul(&value{data: -1})
|
|
}
|
|
|
|
func (v *value) Log() *value {
|
|
return &value{
|
|
data: math.Log(v.data),
|
|
children: []*value{v},
|
|
localGrads: []*value{
|
|
(&value{data: 1}).Div(v),
|
|
},
|
|
}
|
|
}
|
|
|
|
func (v *value) Exp() *value {
|
|
return &value{
|
|
data: math.Exp(v.data),
|
|
children: []*value{v},
|
|
localGrads: []*value{
|
|
{data: math.Exp(v.data)},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (v *value) Relu() *value {
|
|
return &value{
|
|
data: max(v.data, 0),
|
|
children: []*value{v},
|
|
localGrads: []*value{
|
|
{data: btof(v.data > 0)},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (v *value) Backward() {
|
|
topo := []*value{}
|
|
visited := map[string]struct{}{}
|
|
|
|
var buildTopo func(v *value)
|
|
buildTopo = func(v *value) {
|
|
k := v.toKey()
|
|
if _, ok := visited[k]; !ok {
|
|
visited[k] = struct{}{}
|
|
for _, child := range v.children {
|
|
buildTopo(child)
|
|
}
|
|
topo = append(topo, v)
|
|
}
|
|
}
|
|
buildTopo(v)
|
|
v.grad = 1.0
|
|
for _, v := range slices.Backward(topo) {
|
|
for i := range v.children {
|
|
v.children[i].grad += v.localGrads[i].data * v.grad
|
|
}
|
|
}
|
|
}
|
|
|
|
func mkDeepSlice(size int) [][][]*value {
|
|
a := make([][][]*value, 1, 10)
|
|
a[0] = make([][]*value, 0, 10)
|
|
return a
|
|
}
|
|
|
|
// implement our own weighted random chooser
|
|
// based on https://cybernetist.com/2019/01/24/random-weighted-draws-in-go/ but without the dependency on gonum
|
|
func choose(p []*value) int {
|
|
// Initialization: create the discrete CDF
|
|
cdf := make([]float64, len(p))
|
|
for i, v := range p {
|
|
if i == 0 {
|
|
cdf[i] = v.data
|
|
} else {
|
|
cdf[i] = cdf[i-1] + v.data
|
|
}
|
|
}
|
|
// Generation:
|
|
// 1. Generate a uniformly-random value x in the range [0,1)
|
|
// 2. Using a binary search, find the index of the smallest element in cdf larger than x
|
|
var val float64
|
|
// multiply the sample with the largest CDF value; easier than normalizing to [0,1)
|
|
val = rand.Float64() * cdf[len(cdf)-1]
|
|
// Search returns the smallest index i such that cdf[i] > val
|
|
return sort.Search(len(cdf), func(i int) bool { return cdf[i] > val })
|
|
}
|