package microgopt

import (
	"math"
	"testing"
)

// approxEqual reports whether got and want differ by at most 1e-9.
//
// Exact == on float64 is brittle for results that pass through exponentials,
// square roots or summations: a different (but equally valid) evaluation order
// can change the last bits and fail the test for no real reason.
func approxEqual(got, want float64) bool {
	const tol = 1e-9
	return math.Abs(got-want) <= tol
}

// assertDataClose fails the test if got and want have different lengths or if
// any got[i].data differs from want[i].data by more than approxEqual allows.
// fn is the name of the function under test and is used in failure messages.
func assertDataClose(t *testing.T, fn string, got, want []*value) {
	t.Helper()
	if len(got) != len(want) {
		// Fatal rather than Error: the element loop below would index past the
		// end of got and panic if it were shorter than want.
		t.Fatalf("%s() returned %d values, want %d", fn, len(got), len(want))
	}
	for i := range want {
		// Report the data fields themselves; formatting the []*value slices
		// with %v may show only pointer addresses.
		if !approxEqual(got[i].data, want[i].data) {
			t.Errorf("%s()[%d].data = %v, want %v", fn, i, got[i].data, want[i].data)
		}
	}
}

// TestValue runs a small computation on two values, calls Backward on the
// result, and checks the grad left on each of the two inputs.
func TestValue(t *testing.T) {
	tests := []struct {
		name   string
		inputs []*value             // the two values to work on
		f      func(*value, *value) // the work to do, including the Backward call
		want   []float64            // expected grad of each input after f runs
	}{
		{
			name:   "add and multiply",
			inputs: []*value{{data: 2.0}, {data: 3.0}},
			// L = a*b + a, so dL/da = b + 1 = 4 and dL/db = a = 2.
			f: func(a, b *value) {
				c := a.Mul(b)
				L := c.Add(a)
				L.Backward()
			},
			want: []float64{4.0, 2.0},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.f(tt.inputs[0], tt.inputs[1])
			for i, want := range tt.want {
				if got := tt.inputs[i].grad; !approxEqual(got, want) {
					t.Errorf("inputs[%d].grad = %v, want %v", i, got, want)
				}
			}
		})
	}
}

// TestLinear checks linear(x, w) against a hand-computed product in which
// each output element is the dot product of one row of w with x.
func TestLinear(t *testing.T) {
	tests := []struct {
		name string
		x    []*value
		w    [][]*value
		want []*value
	}{
		{
			name: "base case",
			x:    []*value{{data: 1}, {data: 2}, {data: 3}},
			w:    [][]*value{{{data: 4}, {data: 5}, {data: 6}}, {{data: 7}, {data: 8}, {data: 9}}},
			// 4*1 + 5*2 + 6*3 = 32 and 7*1 + 8*2 + 9*3 = 50.
			want: []*value{{data: 32}, {data: 50}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assertDataClose(t, "linear", linear(tt.x, tt.w), tt.want)
		})
	}
}

// TestSoftMax checks softMax against precomputed probabilities.
func TestSoftMax(t *testing.T) {
	tests := []struct {
		name   string
		logits []*value
		want   []*value
	}{
		{
			name:   "base case",
			logits: []*value{{data: 1}, {data: 2}, {data: 3}},
			// e^i / (e^1 + e^2 + e^3) for i = 1, 2, 3.
			want: []*value{{data: 0.09003057317038045}, {data: 0.2447284710547976}, {data: 0.6652409557748218}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assertDataClose(t, "softMax", softMax(tt.logits), tt.want)
		})
	}
}

// TestRmsNorm checks rmsNorm against precomputed outputs.
func TestRmsNorm(t *testing.T) {
	tests := []struct {
		name string
		x    []*value
		want []*value
	}{
		{
			name: "base case",
			x:    []*value{{data: 1}, {data: 2}, {data: 3}},
			// Consistent with x_i / sqrt(mean(x^2) + eps) for eps = 1e-5;
			// NOTE(review): confirm eps against rmsNorm's implementation.
			want: []*value{{data: 0.4629095539120195}, {data: 0.925819107824039}, {data: 1.3887286617360584}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assertDataClose(t, "rmsNorm", rmsNorm(tt.x), tt.want)
		})
	}
}