add some tests

This commit is contained in:
2026-03-08 21:29:51 -04:00
parent aabd0087d5
commit 40fd5e99d1
2 changed files with 109 additions and 14 deletions

View File

@@ -39,11 +39,23 @@ func valcmp(a, b *value) int {
}
}
func sum(l []*value) *value {
r := &value{}
for v := range l {
r = r.Add(l[v])
}
return r
}
func Run(docs []string) {
// remove leading and trailing whitespace in documents
for i := range docs {
docs[i] = strings.TrimSpace(docs[i])
}
rand.Shuffle(
len(docs),
func(i, j int) { docs[i], docs[j] = docs[j], docs[i] },
)
fmt.Printf("num docs: %d\n", len(docs))
// construct the vocabulary from the documents: an ordered list of all characters in the dataset,
@@ -104,7 +116,7 @@ func Run(docs []string) {
n := min(blockSize, len(tokens)-1)
// Forward the token sequence through the model, building up the computation graph all the way to the loss
keys, values := mkDeepSlice(), mkDeepSlice()
keys, values := mkDeepSlice(nLayer), mkDeepSlice(nLayer)
losses := []*value{}
for posId := range n {
tokenId, targetId := tokens[posId], tokens[posId+1]
@@ -113,10 +125,7 @@ func Run(docs []string) {
lossT := probs[targetId].Log().Neg()
losses = append(losses, lossT)
}
lossSum := &value{}
for _, l := range losses {
lossSum = lossSum.Add(l)
}
lossSum := sum(losses)
loss := (&value{data: 1 / float64(n)}).Mul(lossSum) // final average loss over the document sequence. May yours be low.
// Backward the loss, calculating the gradients with respect to all model parameters
loss.Backward()
@@ -138,7 +147,7 @@ func Run(docs []string) {
temperature := 0.5 // in (0, 1], control the "creativity" of generated text, low to high
fmt.Println("\n--- inference (new, hallucinated names) ---")
for sampleIdx := range 20 {
keys, values := mkDeepSlice(), mkDeepSlice()
keys, values := mkDeepSlice(nLayer), mkDeepSlice(nLayer)
tokenId := BOS
sample := []rune{}
for posId := range blockSize {
@@ -187,10 +196,7 @@ func softMax(logits []*value) []*value {
for _, val := range logits {
exps = append(exps, val.Sub(maxVal).Exp())
}
total := &value{}
for _, e := range exps {
total = total.Add(e)
}
total := sum(exps)
for i := range exps {
exps[i] = exps[i].Div(total)
}
@@ -222,7 +228,7 @@ func gpt(tokenId int, posId int, keys [][][]*value, values [][][]*value) []*valu
for li := range nLayer {
// 1) Multi-head Attention block
xResidual := slices.Clone(x)
xResidual := x
x = rmsNorm(x)
q := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wq", li)])
k := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wk", li)])
@@ -290,6 +296,10 @@ type value struct {
rLocalGrad *value
}
func (v *value) String() string {
return fmt.Sprintf("%.16f", v.data)
}
func (v *value) Add(other *value) *value {
return &value{
data: v.data + other.data,
@@ -384,9 +394,11 @@ func (v *value) Backward() {
}
}
func mkDeepSlice() [][][]*value {
a := make([][][]*value, 1, 10)
a[0] = make([][]*value, 0, 10)
func mkDeepSlice(size int) [][][]*value {
a := make([][][]*value, size)
for i := range size {
a[i] = make([][]*value, 0)
}
return a
}

View File

@@ -33,3 +33,86 @@ func TestValue(t *testing.T) {
})
}
}
func TestLinear(t *testing.T) {
tests := []struct {
name string
x []*value
w [][]*value
want []*value
}{
{
name: "base case",
x: []*value{{data: 1}, {data: 2}, {data: 3}},
w: [][]*value{{{data: 4}, {data: 5}, {data: 6}}, {{data: 7}, {data: 8}, {data: 9}}},
want: []*value{{data: 32}, {data: 50}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := linear(tt.x, tt.w)
if len(tt.want) != len(got) {
t.Errorf("linear() = %v, want %v", got, tt.want)
}
for i, v := range tt.want {
if v.data != got[i].data {
t.Errorf("linear() = %v, want %v", got, tt.want)
}
}
})
}
}
func TestSoftMax(t *testing.T) {
tests := []struct {
name string
logits []*value
want []*value
}{
{
name: "base case",
logits: []*value{{data: 1}, {data: 2}, {data: 3}},
want: []*value{{data: 0.09003057317038045}, {data: 0.2447284710547976}, {data: 0.6652409557748218}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := softMax(tt.logits)
if len(tt.want) != len(got) {
t.Errorf("softMax() = %v, want %v", got, tt.want)
}
for i, v := range tt.want {
if v.data != got[i].data {
t.Errorf("softMax() = %v, want %v", got, tt.want)
}
}
})
}
}
func TestRmsNorm(t *testing.T) {
tests := []struct {
name string
x []*value
want []*value
}{
{
name: "base case",
x: []*value{{data: 1}, {data: 2}, {data: 3}},
want: []*value{{data: 0.4629095539120195}, {data: 0.925819107824039}, {data: 1.3887286617360584}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := rmsNorm(tt.x)
if len(tt.want) != len(got) {
t.Errorf("rmsNorm() = %v, want %v", got, tt.want)
}
for i, v := range tt.want {
if v.data != got[i].data {
t.Errorf("rmsNorm() = %v, want %v", got, tt.want)
}
}
})
}
}