Files
microgopt/microgopt_test.go
2026-03-08 21:29:51 -04:00

119 lines
2.6 KiB
Go

package microgopt
import (
"testing"
)
// TestValue checks the gradients that Backward accumulates on the two leaf
// values of a small expression graph built by each test case.
func TestValue(t *testing.T) {
	tests := []struct {
		name   string
		inputs []*value             // the two leaf values to work on
		f      func(*value, *value) // builds the graph from the leaves and calls Backward
		want   []float64            // expected grad of each input, in order
	}{
		{
			// loss = a*b + a, so dloss/da = b + 1 = 4 and dloss/db = a = 2.
			name:   "add and multiply",
			inputs: []*value{{data: 2.0}, {data: 3.0}},
			f: func(a, b *value) {
				prod := a.Mul(b)
				loss := prod.Add(a)
				loss.Backward()
			},
			want: []float64{4.0, 2.0},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// f takes exactly two leaves; a malformed table entry would
			// otherwise panic on the index below.
			if len(tt.inputs) != 2 || len(tt.want) != len(tt.inputs) {
				t.Fatalf("malformed test case: %d inputs, %d expected gradients; want 2 of each",
					len(tt.inputs), len(tt.want))
			}
			tt.f(tt.inputs[0], tt.inputs[1])
			for i, want := range tt.want {
				// The gradients here come from small integer arithmetic and are
				// exact in float64, so == is safe.
				if got := tt.inputs[i].grad; got != want {
					t.Errorf("inputs[%d].grad = %v, want %v", i, got, want)
				}
			}
		})
	}
}
// TestLinear checks that linear returns, for each row of w, the dot product
// of that row with x.
func TestLinear(t *testing.T) {
	tests := []struct {
		name string
		x    []*value
		w    [][]*value
		want []*value
	}{
		{
			// Rows: 4*1 + 5*2 + 6*3 = 32 and 7*1 + 8*2 + 9*3 = 50.
			name: "base case",
			x:    []*value{{data: 1}, {data: 2}, {data: 3}},
			w:    [][]*value{{{data: 4}, {data: 5}, {data: 6}}, {{data: 7}, {data: 8}, {data: 9}}},
			want: []*value{{data: 32}, {data: 50}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := linear(tt.x, tt.w)
			// Stop on a length mismatch: the element loop below would index
			// past the end of got and panic.
			if len(got) != len(tt.want) {
				t.Fatalf("len(linear()) = %d, want %d", len(got), len(tt.want))
			}
			for i, want := range tt.want {
				// Sums of small integers are exact in float64, so == is safe.
				if got[i].data != want.data {
					t.Errorf("linear()[%d].data = %v, want %v", i, got[i].data, want.data)
				}
			}
		})
	}
}
// TestSoftMax checks softMax against reference probabilities
// exp(x_i) / sum_j exp(x_j) computed for a small input.
func TestSoftMax(t *testing.T) {
	// The results involve exp and division, so compare within a tolerance.
	// An equally valid evaluation order inside softMax (for example,
	// subtracting the max logit first) can change the last few ulps.
	const tol = 1e-12
	tests := []struct {
		name   string
		logits []*value
		want   []*value
	}{
		{
			name:   "base case",
			logits: []*value{{data: 1}, {data: 2}, {data: 3}},
			want:   []*value{{data: 0.09003057317038045}, {data: 0.2447284710547976}, {data: 0.6652409557748218}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := softMax(tt.logits)
			// Stop on a length mismatch: the element loop below would index
			// past the end of got and panic.
			if len(got) != len(tt.want) {
				t.Fatalf("len(softMax()) = %d, want %d", len(got), len(tt.want))
			}
			for i, want := range tt.want {
				// Written as !(in range) rather than (out of range) so that a
				// NaN result, for which every comparison is false, fails the test.
				if diff := got[i].data - want.data; !(diff >= -tol && diff <= tol) {
					t.Errorf("softMax()[%d].data = %v, want %v (±%g)", i, got[i].data, want.data, tol)
				}
			}
		})
	}
}
// TestRmsNorm checks rmsNorm against reference values for a small input.
//
// NOTE(review): the expected values sit slightly below x/sqrt(mean(x^2))
// (1/sqrt(14/3) ≈ 0.4629100). That is consistent with rmsNorm adding an
// epsilon of about 1e-5 inside the square root. Confirm against rmsNorm.
func TestRmsNorm(t *testing.T) {
	// The results involve sqrt and division, so compare within a tolerance
	// instead of bit-for-bit. The tolerance is far smaller than the epsilon's
	// effect (about 5e-7), so a missing epsilon is still caught.
	const tol = 1e-12
	tests := []struct {
		name string
		x    []*value
		want []*value
	}{
		{
			name: "base case",
			x:    []*value{{data: 1}, {data: 2}, {data: 3}},
			want: []*value{{data: 0.4629095539120195}, {data: 0.925819107824039}, {data: 1.3887286617360584}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := rmsNorm(tt.x)
			// Stop on a length mismatch: the element loop below would index
			// past the end of got and panic.
			if len(got) != len(tt.want) {
				t.Fatalf("len(rmsNorm()) = %d, want %d", len(got), len(tt.want))
			}
			for i, want := range tt.want {
				// Written as !(in range) rather than (out of range) so that a
				// NaN result, for which every comparison is false, fails the test.
				if diff := got[i].data - want.data; !(diff >= -tol && diff <= tol) {
					t.Errorf("rmsNorm()[%d].data = %v, want %v (±%g)", i, got[i].data, want.data, tol)
				}
			}
		})
	}
}