Overview
Goro is a high-level machine learning library for Go built on Gorgonia. It aims to have the same feel as Keras.
Usage
import (
	g "gorgonia.org/gorgonia"

	. "github.com/aunum/goro/pkg/v1/model"
	"github.com/aunum/goro/pkg/v1/layer"
)

// Inputs: a single-channel 28x28 image and a 10-class label.
x := NewInput("x", []int{1, 28, 28})
y := NewInput("y", []int{10})

// Stack convolution, pooling and fully connected layers, Keras style.
model, _ := NewSequential("mnist")
model.AddLayers(
	layer.Conv2D{Input: 1, Output: 32, Width: 3, Height: 3},
	layer.MaxPooling2D{},
	layer.Conv2D{Input: 32, Output: 64, Width: 3, Height: 3},
	layer.MaxPooling2D{},
	layer.Conv2D{Input: 64, Output: 128, Width: 3, Height: 3},
	layer.MaxPooling2D{},
	layer.Flatten{},
	layer.FC{Input: 128 * 3 * 3, Output: 100},
	layer.FC{Input: 100, Output: 10, Activation: layer.Softmax},
)

// Compile against the inputs defined above, using Gorgonia's RMSProp solver.
optimizer := g.NewRMSPropSolver()
model.Compile(x, y,
	WithOptimizer(optimizer),
	WithLoss(CrossEntropy),
	WithBatchSize(100),
)

// Train and predict on full datasets. xTrain, yTrain and xTest are
// placeholders for tensors you load yourself.
model.Fit(xTrain, yTrain)
prediction, _ := model.Predict(xTest)

// Or train and predict one batch at a time.
model.FitBatch(xTrainBatch, yTrainBatch)
prediction, _ = model.PredictBatch(xTestBatch)
Examples
See the examples folder for example implementations.
There are many examples in the reinforcement learning library Gold.
Docs
Each package contains a README explaining its usage; see also GoDoc.
Contributing
Please open an MR for any issues or feature requests.
Feel free to ping @pbarker on the Gophers Slack.
Roadmap