checkpoint
This commit is contained in:
147
microgopt.go
Normal file
147
microgopt.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package microgopt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
func btof(b bool) float64 {
|
||||
if b {
|
||||
return 1.0
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
func Run(docs []string) {
|
||||
// remove leading and trailing whitespace in documents
|
||||
for i := range docs {
|
||||
docs[i] = strings.TrimSpace(docs[i])
|
||||
}
|
||||
fmt.Printf("num docs: %d", len(docs))
|
||||
|
||||
// construct the vocabulary from the documents: an ordered list of all characters in the dataset,
|
||||
// plus a "Beginning Of Sequence" (BOS) token
|
||||
set := map[rune]struct{}{}
|
||||
for _, doc := range docs {
|
||||
for _, c := range doc {
|
||||
set[c] = struct{}{}
|
||||
}
|
||||
}
|
||||
uchars := slices.Sorted(maps.Keys(set))
|
||||
// BOS := len(uchars)
|
||||
vocabSize := len(uchars) + 1
|
||||
fmt.Printf("vocab size: %d", vocabSize)
|
||||
}
|
||||
|
||||
type value struct {
|
||||
data float64
|
||||
grad float64 // implicitly 0 to start
|
||||
children []*value
|
||||
localGrads []*value
|
||||
}
|
||||
|
||||
func (v *value) toKey() string {
|
||||
k := fmt.Sprintf("%+v", v)
|
||||
fmt.Println(k)
|
||||
return k
|
||||
}
|
||||
|
||||
func (v *value) Add(other *value) *value {
|
||||
return &value{
|
||||
data: v.data + other.data,
|
||||
children: []*value{v, other},
|
||||
localGrads: []*value{{data: 1}, {data: 1}},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *value) Sub(other *value) *value {
|
||||
return v.Add(other.Neg())
|
||||
}
|
||||
|
||||
func (v *value) Div(other *value) *value {
|
||||
return v.Mul(other.Pow(&value{data: -1}))
|
||||
}
|
||||
|
||||
func (v *value) Mul(other *value) *value {
|
||||
return &value{
|
||||
data: v.data * other.data,
|
||||
children: []*value{v, other},
|
||||
localGrads: []*value{
|
||||
{data: other.data},
|
||||
{data: v.data},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *value) Pow(other *value) *value {
|
||||
return &value{
|
||||
data: math.Pow(v.data, other.data),
|
||||
children: []*value{v},
|
||||
localGrads: []*value{
|
||||
other.Mul(&value{data: math.Pow(v.data, other.Sub(&value{data: 1}).data)}),
|
||||
}}
|
||||
}
|
||||
|
||||
func (v *value) Neg() *value {
|
||||
return v.Mul(&value{data: -1})
|
||||
}
|
||||
|
||||
func (v *value) Log() *value {
|
||||
return &value{
|
||||
data: math.Log(v.data),
|
||||
children: []*value{v},
|
||||
localGrads: []*value{
|
||||
(&value{data: 1}).Div(v),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *value) Exp() *value {
|
||||
return &value{
|
||||
data: math.Exp(v.data),
|
||||
children: []*value{v},
|
||||
localGrads: []*value{
|
||||
{data: math.Exp(v.data)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *value) Relu() *value {
|
||||
return &value{
|
||||
data: max(v.data, 0),
|
||||
children: []*value{v},
|
||||
localGrads: []*value{
|
||||
{data: btof(v.data > 0)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *value) Backward() {
|
||||
topo := []*value{}
|
||||
visited := map[string]struct{}{}
|
||||
|
||||
var buildTopo func(v *value)
|
||||
buildTopo = func(v *value) {
|
||||
k := v.toKey()
|
||||
if _, ok := visited[k]; !ok {
|
||||
visited[k] = struct{}{}
|
||||
for _, child := range v.children {
|
||||
buildTopo(child)
|
||||
}
|
||||
topo = append(topo, v)
|
||||
}
|
||||
}
|
||||
buildTopo(v)
|
||||
spew.Dump(topo)
|
||||
v.grad = 1
|
||||
for _, v := range slices.Backward(topo) {
|
||||
for i := range v.children {
|
||||
v.children[i].grad += v.localGrads[i].data * v.grad
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user