checkpoint
This commit is contained in:
35
microgopt_test.go
Normal file
35
microgopt_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package microgopt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputs []*value // the two values to work on
|
||||
f func(*value, *value) // the work to do
|
||||
want []float64 // two gradients expected at the end
|
||||
}{
|
||||
{
|
||||
name: "add and multiply",
|
||||
inputs: []*value{{data: 2.0}, {data: 3.0}},
|
||||
f: func(a, b *value) {
|
||||
c := a.Mul(b)
|
||||
L := c.Add(a)
|
||||
L.Backward()
|
||||
},
|
||||
want: []float64{4.0, 2.0},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.f(tt.inputs[0], tt.inputs[1])
|
||||
for i := range tt.want {
|
||||
if tt.want[i] != tt.inputs[i].grad {
|
||||
t.Errorf("got: %v, want: %v", tt.inputs[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user