add some tests
This commit is contained in:
40
microgopt.go
40
microgopt.go
@@ -39,11 +39,23 @@ func valcmp(a, b *value) int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sum(l []*value) *value {
|
||||||
|
r := &value{}
|
||||||
|
for v := range l {
|
||||||
|
r = r.Add(l[v])
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
func Run(docs []string) {
|
func Run(docs []string) {
|
||||||
// remove leading and trailing whitespace in documents
|
// remove leading and trailing whitespace in documents
|
||||||
for i := range docs {
|
for i := range docs {
|
||||||
docs[i] = strings.TrimSpace(docs[i])
|
docs[i] = strings.TrimSpace(docs[i])
|
||||||
}
|
}
|
||||||
|
rand.Shuffle(
|
||||||
|
len(docs),
|
||||||
|
func(i, j int) { docs[i], docs[j] = docs[j], docs[i] },
|
||||||
|
)
|
||||||
fmt.Printf("num docs: %d\n", len(docs))
|
fmt.Printf("num docs: %d\n", len(docs))
|
||||||
|
|
||||||
// construct the vocabulary from the documents: an ordered list of all characters in the dataset,
|
// construct the vocabulary from the documents: an ordered list of all characters in the dataset,
|
||||||
@@ -104,7 +116,7 @@ func Run(docs []string) {
|
|||||||
n := min(blockSize, len(tokens)-1)
|
n := min(blockSize, len(tokens)-1)
|
||||||
|
|
||||||
// Forward the token sequence through the model, building up the computation graph all the way to the loss
|
// Forward the token sequence through the model, building up the computation graph all the way to the loss
|
||||||
keys, values := mkDeepSlice(), mkDeepSlice()
|
keys, values := mkDeepSlice(nLayer), mkDeepSlice(nLayer)
|
||||||
losses := []*value{}
|
losses := []*value{}
|
||||||
for posId := range n {
|
for posId := range n {
|
||||||
tokenId, targetId := tokens[posId], tokens[posId+1]
|
tokenId, targetId := tokens[posId], tokens[posId+1]
|
||||||
@@ -113,10 +125,7 @@ func Run(docs []string) {
|
|||||||
lossT := probs[targetId].Log().Neg()
|
lossT := probs[targetId].Log().Neg()
|
||||||
losses = append(losses, lossT)
|
losses = append(losses, lossT)
|
||||||
}
|
}
|
||||||
lossSum := &value{}
|
lossSum := sum(losses)
|
||||||
for _, l := range losses {
|
|
||||||
lossSum = lossSum.Add(l)
|
|
||||||
}
|
|
||||||
loss := (&value{data: 1 / float64(n)}).Mul(lossSum) // final average loss over the document sequence. May yours be low.
|
loss := (&value{data: 1 / float64(n)}).Mul(lossSum) // final average loss over the document sequence. May yours be low.
|
||||||
// Backward the loss, calculating the gradients with respect to all model parameters
|
// Backward the loss, calculating the gradients with respect to all model parameters
|
||||||
loss.Backward()
|
loss.Backward()
|
||||||
@@ -138,7 +147,7 @@ func Run(docs []string) {
|
|||||||
temperature := 0.5 // in (0, 1], control the "creativity" of generated text, low to high
|
temperature := 0.5 // in (0, 1], control the "creativity" of generated text, low to high
|
||||||
fmt.Println("\n--- inference (new, hallucinated names) ---")
|
fmt.Println("\n--- inference (new, hallucinated names) ---")
|
||||||
for sampleIdx := range 20 {
|
for sampleIdx := range 20 {
|
||||||
keys, values := mkDeepSlice(), mkDeepSlice()
|
keys, values := mkDeepSlice(nLayer), mkDeepSlice(nLayer)
|
||||||
tokenId := BOS
|
tokenId := BOS
|
||||||
sample := []rune{}
|
sample := []rune{}
|
||||||
for posId := range blockSize {
|
for posId := range blockSize {
|
||||||
@@ -187,10 +196,7 @@ func softMax(logits []*value) []*value {
|
|||||||
for _, val := range logits {
|
for _, val := range logits {
|
||||||
exps = append(exps, val.Sub(maxVal).Exp())
|
exps = append(exps, val.Sub(maxVal).Exp())
|
||||||
}
|
}
|
||||||
total := &value{}
|
total := sum(exps)
|
||||||
for _, e := range exps {
|
|
||||||
total = total.Add(e)
|
|
||||||
}
|
|
||||||
for i := range exps {
|
for i := range exps {
|
||||||
exps[i] = exps[i].Div(total)
|
exps[i] = exps[i].Div(total)
|
||||||
}
|
}
|
||||||
@@ -222,7 +228,7 @@ func gpt(tokenId int, posId int, keys [][][]*value, values [][][]*value) []*valu
|
|||||||
|
|
||||||
for li := range nLayer {
|
for li := range nLayer {
|
||||||
// 1) Multi-head Attention block
|
// 1) Multi-head Attention block
|
||||||
xResidual := slices.Clone(x)
|
xResidual := x
|
||||||
x = rmsNorm(x)
|
x = rmsNorm(x)
|
||||||
q := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wq", li)])
|
q := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wq", li)])
|
||||||
k := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wk", li)])
|
k := linear(x, stateMap[fmt.Sprintf("layer%d.attn_wk", li)])
|
||||||
@@ -290,6 +296,10 @@ type value struct {
|
|||||||
rLocalGrad *value
|
rLocalGrad *value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *value) String() string {
|
||||||
|
return fmt.Sprintf("%.16f", v.data)
|
||||||
|
}
|
||||||
|
|
||||||
func (v *value) Add(other *value) *value {
|
func (v *value) Add(other *value) *value {
|
||||||
return &value{
|
return &value{
|
||||||
data: v.data + other.data,
|
data: v.data + other.data,
|
||||||
@@ -384,9 +394,11 @@ func (v *value) Backward() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mkDeepSlice() [][][]*value {
|
func mkDeepSlice(size int) [][][]*value {
|
||||||
a := make([][][]*value, 1, 10)
|
a := make([][][]*value, size)
|
||||||
a[0] = make([][]*value, 0, 10)
|
for i := range size {
|
||||||
|
a[i] = make([][]*value, 0)
|
||||||
|
}
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,3 +33,86 @@ func TestValue(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLinear(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
x []*value
|
||||||
|
w [][]*value
|
||||||
|
want []*value
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "base case",
|
||||||
|
x: []*value{{data: 1}, {data: 2}, {data: 3}},
|
||||||
|
w: [][]*value{{{data: 4}, {data: 5}, {data: 6}}, {{data: 7}, {data: 8}, {data: 9}}},
|
||||||
|
want: []*value{{data: 32}, {data: 50}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := linear(tt.x, tt.w)
|
||||||
|
if len(tt.want) != len(got) {
|
||||||
|
t.Errorf("linear() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
for i, v := range tt.want {
|
||||||
|
if v.data != got[i].data {
|
||||||
|
t.Errorf("linear() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoftMax(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
logits []*value
|
||||||
|
want []*value
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "base case",
|
||||||
|
logits: []*value{{data: 1}, {data: 2}, {data: 3}},
|
||||||
|
want: []*value{{data: 0.09003057317038045}, {data: 0.2447284710547976}, {data: 0.6652409557748218}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := softMax(tt.logits)
|
||||||
|
if len(tt.want) != len(got) {
|
||||||
|
t.Errorf("softMax() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
for i, v := range tt.want {
|
||||||
|
if v.data != got[i].data {
|
||||||
|
t.Errorf("softMax() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRmsNorm(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
x []*value
|
||||||
|
want []*value
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "base case",
|
||||||
|
x: []*value{{data: 1}, {data: 2}, {data: 3}},
|
||||||
|
want: []*value{{data: 0.4629095539120195}, {data: 0.925819107824039}, {data: 1.3887286617360584}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := rmsNorm(tt.x)
|
||||||
|
if len(tt.want) != len(got) {
|
||||||
|
t.Errorf("rmsNorm() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
for i, v := range tt.want {
|
||||||
|
if v.data != got[i].data {
|
||||||
|
t.Errorf("rmsNorm() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user