diff --git a/microgopt.go b/microgopt.go index 5431c4c..ef471ad 100644 --- a/microgopt.go +++ b/microgopt.go @@ -284,22 +284,19 @@ func gpt(tokenId int, posId int, keys [][][]*value, values [][][]*value) []*valu type value struct { data float64 grad float64 // implicitly 0 to start - children []*value - localGrads []*value -} - -// this lets us build a set-like map with our Values. -// If the slices were removed from the struct, that would make this method irrelevant. -func (v *value) toKey() string { - k := fmt.Sprintf("%+v", v) - return k + lChild *value + rChild *value + lLocalGrad *value + rLocalGrad *value } func (v *value) Add(other *value) *value { return &value{ data: v.data + other.data, - children: []*value{v, other}, - localGrads: []*value{{data: 1.0}, {data: 1.0}}, + lChild: v, + rChild: other, + lLocalGrad: &value{data: 1.0}, + rLocalGrad: &value{data: 1.0}, } } @@ -314,22 +311,20 @@ func (v *value) Div(other *value) *value { func (v *value) Mul(other *value) *value { // note the swap here: children are stored as v, other but grads are other, v return &value{ - data: v.data * other.data, - children: []*value{v, other}, - localGrads: []*value{ - {data: other.data}, - {data: v.data}, - }, + data: v.data * other.data, + rChild: v, + lChild: other, + rLocalGrad: &value{data: other.data}, + lLocalGrad: &value{data: v.data}, } } func (v *value) Pow(other *value) *value { return &value{ - data: math.Pow(v.data, other.data), - children: []*value{v}, - localGrads: []*value{ - other.Mul(&value{data: math.Pow(v.data, other.Sub(&value{data: 1}).data)}), - }} + data: math.Pow(v.data, other.data), + rChild: v, + rLocalGrad: other.Mul(&value{data: math.Pow(v.data, other.Sub(&value{data: 1}).data)}), + } } func (v *value) Neg() *value { @@ -338,45 +333,41 @@ func (v *value) Neg() *value { func (v *value) Log() *value { return &value{ - data: math.Log(v.data), - children: []*value{v}, - localGrads: []*value{ - (&value{data: 1}).Div(v), - }, + data: math.Log(v.data), + rChild: v, + rLocalGrad: (&value{data: 1}).Div(v), } } func (v *value) Exp() *value { return &value{ - data: math.Exp(v.data), - children: []*value{v}, - localGrads: []*value{ - {data: math.Exp(v.data)}, - }, + data: math.Exp(v.data), + rChild: v, + rLocalGrad: &value{data: math.Exp(v.data)}, } } func (v *value) Relu() *value { return &value{ - data: max(v.data, 0), - children: []*value{v}, - localGrads: []*value{ - {data: btof(v.data > 0)}, - }, + data: max(v.data, 0), + rChild: v, + rLocalGrad: &value{data: btof(v.data > 0)}, } } func (v *value) Backward() { topo := []*value{} - visited := map[string]struct{}{} + visited := map[*value]struct{}{} var buildTopo func(v *value) buildTopo = func(v *value) { - k := v.toKey() - if _, ok := visited[k]; !ok { - visited[k] = struct{}{} - for _, child := range v.children { - buildTopo(child) + if _, ok := visited[v]; !ok { + visited[v] = struct{}{} + if v.rChild != nil { + buildTopo(v.rChild) + } + if v.lChild != nil { + buildTopo(v.lChild) } topo = append(topo, v) } @@ -384,8 +375,11 @@ func (v *value) Backward() { buildTopo(v) v.grad = 1.0 for _, v := range slices.Backward(topo) { - for i := range v.children { - v.children[i].grad += v.localGrads[i].data * v.grad + if v.rChild != nil { + v.rChild.grad += v.rLocalGrad.data * v.grad + } + if v.lChild != nil { + v.lChild.grad += v.lLocalGrad.data * v.grad } } }