add some tests
This commit is contained in:
@@ -33,3 +33,86 @@ func TestValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLinear(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
x []*value
|
||||
w [][]*value
|
||||
want []*value
|
||||
}{
|
||||
{
|
||||
name: "base case",
|
||||
x: []*value{{data: 1}, {data: 2}, {data: 3}},
|
||||
w: [][]*value{{{data: 4}, {data: 5}, {data: 6}}, {{data: 7}, {data: 8}, {data: 9}}},
|
||||
want: []*value{{data: 32}, {data: 50}},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := linear(tt.x, tt.w)
|
||||
if len(tt.want) != len(got) {
|
||||
t.Errorf("linear() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i, v := range tt.want {
|
||||
if v.data != got[i].data {
|
||||
t.Errorf("linear() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftMax(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
logits []*value
|
||||
want []*value
|
||||
}{
|
||||
{
|
||||
name: "base case",
|
||||
logits: []*value{{data: 1}, {data: 2}, {data: 3}},
|
||||
want: []*value{{data: 0.09003057317038045}, {data: 0.2447284710547976}, {data: 0.6652409557748218}},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := softMax(tt.logits)
|
||||
if len(tt.want) != len(got) {
|
||||
t.Errorf("softMax() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i, v := range tt.want {
|
||||
if v.data != got[i].data {
|
||||
t.Errorf("softMax() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRmsNorm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
x []*value
|
||||
want []*value
|
||||
}{
|
||||
{
|
||||
name: "base case",
|
||||
x: []*value{{data: 1}, {data: 2}, {data: 3}},
|
||||
want: []*value{{data: 0.4629095539120195}, {data: 0.925819107824039}, {data: 1.3887286617360584}},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := rmsNorm(tt.x)
|
||||
if len(tt.want) != len(got) {
|
||||
t.Errorf("rmsNorm() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i, v := range tt.want {
|
||||
if v.data != got[i].data {
|
||||
t.Errorf("rmsNorm() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user