提到神经网络模型,我们一般都会想到用 Python 来训练和预测,但其实用 GoLang 也可以。
背景 前段时间,总有「每天拍云」的用户反馈「小程序无反应」。排查之后发现,某些机型有时会出现加载模型失败的情况。为了优化用户体验,尽可能地让每一个用户都能成功识云,需要在服务端提供一个传图识云的接口。
然而,问题来了!因为服务端是用 GoLang 开发的(具体可戳 抢救了一个能给云朵分类的微信小程序 ),要怎么样快速提供这个接口呢?
解决思路 遇到这个问题,想到的首个解决方法就是用其它语言专门开发这个接口。Python 作为机器学习领域的一等公民,是首选项;小程序用的 tfjs 加载模型,所以 NodeJS 也是一个可选项。
但因为尝试过程中,环境和依赖安装问题并不顺利,加上我又想到了一句至理名言 —— 如无必要不引入新的实体,所以我把方向转为「用 GoLang 加载模型预测」。
关键步骤 本文以 MacOS 开发环境为例,分享其中的关键步骤。
安装依赖
使用 brew install libtensorflow
安装 TensorFlow
使用 go get github.com/galeone/tfgo
安装 tfgo 库
导出模型 如果是自训练模型,修改模型代码,用以下方式导出模型即可。
1 2 3 import tensorflow as tftf.saved_model.save(model, "/path/to/saved_model" )
如果是用的开源模型,而且模型只提供了 fronzen graph
,可以参考以下代码转成 saved_model
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 import tensorflow as tffrom tensorflow.python.framework import graph_iofrom tensorflow.python.framework.importer import import_graph_deffrom tensorflow.python.saved_model import signature_constantsfrom tensorflow.python.saved_model import tag_constantsfrom tensorflow.python.saved_model import builder as saved_model_builderfrozen_graph_path = '/path/to/frozen_inference_graph.pb' saved_model_dir = '/path/to/saved_model' with tf.io.gfile.GFile(frozen_graph_path, "rb" ) as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) with tf.compat.v1.Session(graph=tf.Graph()) as sess: tf.import_graph_def(graph_def, name="" ) input_tensor = sess.graph.get_tensor_by_name("ImageTensor:0" ) output_tensor = sess.graph.get_tensor_by_name("SemanticPredictions:0" ) builder = saved_model_builder.SavedModelBuilder(saved_model_dir) signature = tf.compat.v1.saved_model.signature_def_utils.build_signature_def( inputs={'input' : tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)}, outputs={'output' : tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)}, method_name=signature_constants.PREDICT_METHOD_NAME ) builder.add_meta_graph_and_variables( sess, [tag_constants.SERVING], signature_def_map={ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }) builder.save()
模型加载和预测 对于熟悉 TensorFlow 的同学,模型的加载和预测倒是不难理解。如果不熟悉,可以把整个模型理解为一张有向图,图中的每个节点输出都是一个 Tensor(张量)。
而模型的加载过程,就是把这张图加载到内存中。模型的预测,则是提供输入节点的值,以及要计算节点的 ID,然后 TensorFlow 会从输入节点开始,层层计算中间节点的值,直到所有预期的输出节点都计算完毕。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 import ( "context" "fmt" tf "github.com/galeone/tensorflow/tensorflow/go" tg "github.com/galeone/tfgo" ) func predict (input *tf.Tensor) { model := tg.LoadModel("/path/to/saved_model" , []string {"serve" }, nil ) execResult := model.Exec([]tf.Output{ model.Op("SemanticPredictions" , 0 ), }, map [tf.Output]*tf.Tensor{ model.Op("ImageTensor" , 0 ): input, }) fmt.Println(execResult[0 ]) }
注意事项 版本一致 导出模型的 Python 代码中,使用的 TensorFlow 的库版本应该尽量和 GoLang 中使用的 TensorFlow 的版本保持一致,否则可能出现异常。
找不到 libtensorflow
在 MacOS 上运行代码时,可能会遇到找不到 libtensorflow
的错误,可以手动指定 libtensorflow
的位置解决:
1 CGO_CFLAGS="-I/opt/homebrew/opt/libtensorflow/include" CGO_LDFLAGS="-L/opt/homebrew/opt/libtensorflow/lib -Wl,-rpath,/opt/homebrew/opt/libtensorflow/lib" gf run main.go
容器化 打包镜像时,也同样需要安装版本一致的 libtensorflow
,可参考以下示例:
1 2 3 4 5 6 7 8 9 FROM golang as builderWORKDIR /app COPY . /app RUN curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.15.0.tar.gz" | tar -C /usr/local -xz \ && ldconfig \ && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o main main.go