如何使用 GoLang 加载神经网络模型

提到神经网络模型,我们一般都会想到用 Python 来训练和预测,但其实用 GoLang 也可以。

背景

前段时间,总有「每天拍云」的用户反馈「小程序无反应」。排查之后发现,某些机型有时会出现加载模型失败的情况。为了优化用户体验,尽可能地让每一个用户都能成功识云,需要在服务端提供一个传图识云的接口。

然而,问题来了!因为服务端是用 GoLang 开发的(具体可戳 抢救了一个能给云朵分类的微信小程序),要怎么样快速提供这个接口呢?

解决思路

遇到这个问题,想到的首个解决方法就是用其它语言专门开发这个接口。Python 作为机器学习领域的一等公民,是首选项;小程序用的 tfjs 加载模型,所以 NodeJS 也是一个可选项。

但因为尝试过程中,环境和依赖安装问题并不顺利,加上我又想到了一句至理名言 —— 如无必要不引入新的实体,所以我把方向转为「用 GoLang 加载模型预测」。

关键步骤

本文以 MacOS 开发环境为例,分享其中的关键步骤。

安装依赖

  • 使用 brew install libtensorflow 安装 TensorFlow
  • 使用 go get github.com/galeone/tfgo 安装 tfgo

导出模型

如果是自训练模型,修改模型代码,用以下方式导出模型即可。

1
2
3
import tensorflow as tf

tf.saved_model.save(model, "/path/to/saved_model")

如果是用的开源模型,而且模型只提供了 fronzen graph,可以参考以下代码转成 saved_model

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import builder as saved_model_builder

# 设定输入路径和输出路径
frozen_graph_path = '/path/to/frozen_inference_graph.pb'
saved_model_dir = '/path/to/saved_model'

# 加载 frozen graph
with tf.io.gfile.GFile(frozen_graph_path, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())

# 创建新的图
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
tf.import_graph_def(graph_def, name="")

# 获取模型输入和输出节点(根据你的模型修改节点名称)
# 如果不清楚输出输出节点,也可以通过 https://netron.app/ 加载模型文件,在线预览节点
input_tensor = sess.graph.get_tensor_by_name("ImageTensor:0")
output_tensor = sess.graph.get_tensor_by_name("SemanticPredictions:0")

# 构建 SavedModel builder
builder = saved_model_builder.SavedModelBuilder(saved_model_dir)

# 定义模型的 Signature (用于推理时的输入输出映射)
signature = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)},
outputs={'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)},
method_name=signature_constants.PREDICT_METHOD_NAME
)

# 保存模型
builder.add_meta_graph_and_variables(
sess,
[tag_constants.SERVING],
signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
})
builder.save()

模型加载和预测

对于熟悉 TensorFlow 的同学,模型的加载和预测倒是不难理解。如果不熟悉,可以把整个模型理解为一张有向图,图中的每个节点输出都是一个 Tensor(张量)。

而模型的加载过程,就是把这张图加载到内存中。模型的预测,则是提供输入节点的值,以及要计算节点的 ID,然后 TensorFlow 会从输入节点开始,层层计算中间节点的值,直到所有预期的输出节点都计算完毕。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import (
"context"
"fmt"
tf "github.com/galeone/tensorflow/tensorflow/go"
tg "github.com/galeone/tfgo"
)

func predict(input *tf.Tensor) {
// 加载模型
model := tg.LoadModel("/path/to/saved_model", []string{"serve"}, nil)

// 定义要计算的节点
execResult := model.Exec([]tf.Output{
model.Op("SemanticPredictions", 0),
}, map[tf.Output]*tf.Tensor{
model.Op("ImageTensor", 0): input,
})

// 执行结果是一个数组
fmt.Println(execResult[0])
}

注意事项

版本一致

导出模型的 Python 代码中,使用的 TensorFlow 的库版本应该尽量和 GoLang 中使用的 TensorFlow 的版本保持一致,否则可能出现异常。

找不到 libtensorflow

在 MacOS 上运行代码时,可能会遇到找不到 libtensorflow 的错误,可以手动指定 libtensorflow 的位置解决:

1
CGO_CFLAGS="-I/opt/homebrew/opt/libtensorflow/include" CGO_LDFLAGS="-L/opt/homebrew/opt/libtensorflow/lib -Wl,-rpath,/opt/homebrew/opt/libtensorflow/lib" gf run main.go

容器化

打包镜像时,也同样需要安装版本一致的 libtensorflow,可参考以下示例:

1
2
3
4
5
6
7
8
9
FROM golang as builder

WORKDIR /app

COPY . /app

RUN curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.15.0.tar.gz" | tar -C /usr/local -xz \
&& ldconfig \
&& CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -o main main.go