nn-sdk: a model inference engine for TensorFlow (v1, v2), ONNX, TensorRT and fastText models
更多使用参见: https://github.com/ssbuild/nn-sdk
Supported development languages: C/C++, Python, Java.
Supported inference engines: TensorFlow (v1, v2), ONNX Runtime, TensorRT and fastText. Note: TensorRT 7 and 8 have passed testing (8 is recommended); TensorRT is currently supported on Linux only.
Supports multiple subgraphs and graphs with multiple inputs and multiple outputs; supported model formats: pb [TensorFlow 1, 2], ckpt [TensorFlow], trt [TensorRT] and fastText.
Supports FasterTransformer pb models [FP32 precision; 1.9x speed-up compared with standard TensorFlow].
Model conversion: pip install tf2pb; for pb model conversion with tf2pb see https://pypi.org/project/tf2pb
Model encryption: see test_aes.py; currently supports encrypting TensorFlow 1 pb, ONNX, TensorRT and fastText models.
Recommended environments: Ubuntu, CentOS 7, CentOS 8, Windows.
Examples: Python (test_py.py), C (test.c), Java package (nn_sdk.java).
For more usage see: https://github.com/ssbuild/nn-sdk
aes: 加密参考test_aes.py
0: tensorflow
1: onnx
2: tensorrt
3: fasttext
0: fatal
2: error
4: warn
8: info
16: debug
model_type: tensorflow model type
0: pb format
1: ckpt format
fastertransformer算子,模型转换参考tf2pb, 参考 https://pypi.org/project/tf2pb
ConfigProto: tensorflow 显卡配置
device_id: GPU id
engine_major: 推理引擎主版本 tf 0,1 tensorrt 7 或者 8 , fasttext 0
engine_minor: 推理引擎次版本
graph: 多子图配置
node: 例子: tensorflow 1 input_ids:0 , tensorflow 2: input_ids , onnx: input_ids
dtype: 节点的类型根据模型配置,对于c++/java支持 int int64 long longlong float double str
shape: 尺寸维度
2022-07-28 enable tf1 reset_default_graph
2022-06-23 split TensorRT support out into trt_sdk, optimize the ONNX engine and fix an ONNX engine reload bug.
2022-01-21 allow graph shape definitions to contain None, update the demo notes, fix a TensorFlow 2 inference dtype bug,
and remove a DeprecationWarning on Python >= 3.8
2021-12-09 graph data_type 改名 dtype , 除fatal info err debug 增加warn
2021-11-25 修复nn-sdk非主动close, close小bug.
2021-10-21 修复fasttext推理向量维度bug
2021-10-16 优化 c++/java接口,可预测动态batch
2021-10-07 增加 fasttext 向量和标签推理
python demo
# Python demo: describe the model in a config dict, create an engine with
# csdk_object(config), then run one inference call.
# NOTE(review): this snippet appears to have lost lines during extraction:
# closing braces/brackets, the section keys that group the settings below
# (the Java demo nests them under 'tf', 'saved_model', 'trt' and 'fasttext'),
# and the graph node entries are missing, so it does not run as shown.
# TODO restore from the upstream README.
from nn_sdk import *
config = {
# Path to the model file.
"model_dir": r'/root/model.pb',
# Log level: 0 fatal, 2 error, 4 warn, 8 info, 16 debug.
"log_level": 8,
# GPU id.
"device_id": 0,
# TensorFlow session configuration (ConfigProto).
"ConfigProto": {
"log_device_placement": False,
"allow_soft_placement": True,
"gpu_options": {"allow_growth": True},
# NOTE(review): the Java demo nests optimizer_options under "graph_options" — confirm.
"optimizer_options":{"global_jit_level": 1}
# TensorFlow engine major version.
"engine_major": 1,
# Presumably resets the TF1 default graph before loading — confirm.
"is_reset_graph": 1,
# TensorFlow model format: 0 pb, 1 ckpt.
"model_type": 0,
# saved_model options (a "saved_model" section in the Java demo): enable only
# for a frozen saved_model pb, together with its tags and signature key.
'enable': False,
'tags': ['serve'],
'signature_key': 'serving_default',
# FasterTransformer ops; for model conversion see tf2pb.
"fastertransformer":{"enable": False}
# NOTE(review): unclear which section this key belongs to — confirm.
'tensorrt': True,
# TensorRT settings (grouped under 'trt' in the Java demo): engine major/minor version.
"engine_major": 8,
"engine_minor": 0,
"enable_graph": 0,
# fastText settings.
'fasttext': {
"engine_major": 0,
# 1: output the internal labels so the caller can decode them.
"dump_label": 1,
# 1: predict labels, 0: return vectors.
"predict_label": 1,
# List of subgraphs; each entry lists its input and output nodes
# (node name, dtype, shape).
"graph": [
"input": [
"output": [
# A batch of one row of 256 ids (all ones) and a matching all-ones mask.
seq_length = 256
input_ids = [[1] * seq_length]
input_mask = [[1] * seq_length]
# Build the engine from the config; inference only runs if valid() is true.
sdk_inf = csdk_object(config)
if sdk_inf.valid():
# Presumably the index of the subgraph in "graph" to run — confirm.
net_stage = 0
# Inputs are passed positionally, presumably in the order of the graph's
# "input" list — confirm. Returns a (ret, out) pair.
ret, out = sdk_inf.process(net_stage, input_ids,input_mask)
# NOTE(review): the engine is never released in this snippet; if the SDK
# exposes close(), call it when done.
java demo
package nn_sdk;
/**
 * Host-side input/output buffers handed to the native sdk_process_cc call.
 * Each buffer holds batch_size rows of ROW_WIDTH floats, stored flattened.
 */
class nn_buffer_batch{
    // Number of values per row (the demo graph declares shape [*, 10]).
    private static final int ROW_WIDTH = 10;

    // Flattened input rows: batch_size * ROW_WIDTH values.
    public float [] input_ids = null;
    // Flattened output rows: batch_size * ROW_WIDTH values, expected to be
    // filled by the native call.
    public float[] pred_ids = null;
    public int batch_size = 1;

    public nn_buffer_batch(int batch_size_){
        this.input_ids = new float[batch_size_ * ROW_WIDTH];
        this.pred_ids = new float[batch_size_ * ROW_WIDTH];
        this.batch_size = batch_size_;
        // Initialise every row; the previous loop bound (1 * 10) only touched the
        // first row, leaving the remaining input rows at 0.
        for (int i = 0; i < batch_size_ * ROW_WIDTH; i++) {
            this.input_ids[i] = 1;
            this.pred_ids[i] = 0;
        }
    }
}
// Java demo: declares the JNI bindings of the native nn_sdk library and runs
// one inference on a TensorFlow ckpt model.
// NOTE(review): this extract appears to be missing lines: the body of the
// static initializer (presumably loading the native library), the cleanup
// after inference, and the closing braces. TODO restore from upstream.
public class nn_sdk {
// NOTE(review): the C demo calls sdk_init_cc() before sdk_new_cc(); main below
// never calls it — confirm whether the Java side must too.
public native static int sdk_init_cc();
public native static int sdk_uninit_cc();
// Creates an engine from a config string and returns its native handle.
public native static long sdk_new_cc(String json);
// Presumably releases the engine behind handle — confirm.
public native static int sdk_delete_cc(long handle);
// Runs subgraph net_state on batch_size rows of buffer; main treats a return
// value of 0 as success. Presumably reads buffer.input_ids and fills
// buffer.pred_ids (named after the graph's nodes) — confirm.
public native static int sdk_process_cc(long handle, int net_state,int batch_size, nn_buffer_batch buffer);
static {
// NOTE(review): static initializer body missing from this extract.
public static void main(String[] args){
System.out.println("java main...........");
// NOTE(review): unused — every native method above is static.
nn_sdk instance = new nn_sdk();
// NOTE(review): batch size 2, but the graph below declares shape [1, 10]; the
// config's own note says inputs must not exceed max_batch_size (the first shape
// dimension) for C++/Java — confirm.
nn_buffer_batch buf = new nn_buffer_batch(2);
// Engine config. NOTE(review): this string uses Python literal syntax (r'...',
// True/False, bytes([...]), single quotes, '#' comments) rather than strict
// JSON; presumably the native side parses it leniently — confirm. It also uses
// the older keys "engine_version" and "data_type" (the 2021-12-09 changelog
// renames data_type to dtype).
// The embedded '#' notes (Chinese) say, in English:
//  - ConfigProto has no effect for TensorFlow 2;
//  - saved_model applies only when model_type is pb: plain pb uses enable=False,
//    a frozen saved_model needs enable and tags;
//  - fasttext: threshold for the k predicted labels; k labels to predict;
//    dump_label outputs internal labels for upper-layer decoding;
//    predict_label 1 returns labels, 0 returns vectors;
//  - graph shape is [max_batch_size, max_seq_length]; C++/Java use
//    max_batch_size to allocate buffers and inputs must not exceed it; Python
//    ignores it; for fasttext the node names are arbitrary but required.
String json = "{\r\n"
+ " \"model_dir\": r'model.ckpt',\r\n"
+ " \"aes\":{\r\n"
+ " \"enable\":False,\r\n"
+ " \"key\":bytes([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]),\r\n"
+ " \"iv\":bytes([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]),\r\n"
+ " },\r\n"
+ " \"log_level\": 4,# fatal 1 , error 2 , info 4 , debug 8\r\n"
+ " 'engine':0, # 0 tensorflow, 1 onnx , 2 tensorrt , 3 fasttext\r\n"
+ " \"device_id\": 0,\r\n"
+ " 'tf':{\r\n"
+ " #tensorflow2 ConfigProto无效\r\n"
+ " \"ConfigProto\": {\r\n"
+ " \"log_device_placement\": False,\r\n"
+ " \"allow_soft_placement\": True,\r\n"
+ " \"gpu_options\": {\r\n"
+ " \"allow_growth\": True\r\n"
+ " },\r\n"
+ " \"graph_options\":{\r\n"
+ " \"optimizer_options\":{\r\n"
+ " \"global_jit_level\": 1\r\n"
+ " }\r\n"
+ " },\r\n"
+ " },\r\n"
+ " \"engine_version\": 1, # tensorflow版本\r\n"
+ " \"model_type\": 1,# 0 pb , 1 ckpt\r\n"
+ " \"saved_model\":{ # 当model_type为pb模型有效, 普通pb enable=False , 如果是saved_model冻结模型 , 则需启用enable并且配置tags\r\n"
+ " 'enable': False, # 是否启用saved_model\r\n"
+ " 'tags': ['serve'],\r\n"
+ " 'signature_key': 'serving_default',\r\n"
+ " },\r\n"
+ " \"fastertransformer\":{\r\n"
+ " \"enable\": False,\r\n"
+ " }\r\n"
+ " },\r\n"
+ " 'onnx':{\r\n"
+ " \"engine_version\": 1,# onnxruntime 版本\r\n"
+ " },\r\n"
+ " 'trt':{\r\n"
+ " \"engine_version\": 8,# tensorrt 版本\r\n"
+ " \"enable_graph\": 0,\r\n"
+ " },\r\n"
+ " 'fasttext': {\r\n"
+ " \"engine_version\": 0,# fasttext主版本\r\n"
+ " \"threshold\":0, # 预测k个标签的阈值\r\n"
+ " \"k\":1, # 预测k个标签\r\n"
+ " \"dump_label\": 1, #输出内部标签,用于上层解码\r\n"
+ " \"predict_label\": 1, #获取预测标签 1 , 获取向量 0\r\n"
+ " },\r\n"
+ " \"graph\": [\r\n"
+ " {\r\n"
+ " # 对于Bert模型 shape [max_batch_size,max_seq_lenth],\r\n"
+ " # 其中max_batch_size 用于c++ java开辟输入输出缓存,输入不得超过max_batch_size,对于python没有作用,取决于上层用户真实输入\r\n"
+ " # python限制max_batch_size 在上层用户输入做\r\n"
+ " # 对于fasttext node 对应name可以任意写,但不能少\r\n"
+ " \"input\": [\r\n"
+ " {\"node\":\"input_ids:0\", \"data_type\":\"float\", \"shape\":[1, 10]},\r\n"
+ " ],\r\n"
+ " \"output\": [\r\n"
+ " {\"node\":\"pred_ids:0\", \"data_type\":\"float\", \"shape\":[1, 10]},\r\n"
+ " ],\r\n"
+ " }\r\n"
+ " ]}";
// NOTE(review): the handle is printed but not checked before use — confirm
// what value sdk_new_cc returns on failure.
long handle = sdk_new_cc(json);
System.out.printf("handle: %d\n",handle);
// Run subgraph 0 on buf.batch_size rows.
int code = sdk_process_cc(handle,0,buf.batch_size,buf);
System.out.printf("sdk_process_cc %d \n" ,code);
if(code == 0) {
// 20 = batch size 2 * 10 outputs per row.
for(int i = 0;i<20 ; i++) {
System.out.printf("%f ",buf.pred_ids[i]);
c/c++ demo
#include <stdio.h>
#include "nn_sdk.h"
// C/C++ demo: initialise the SDK, create an engine from an inline config,
// run one inference on a single [1, 10] input and print the 10 output floats.
// NOTE(review): this extract is truncated: the closing braces, the end of the
// config string literal, the bodies of the cleanup loops at the end and
// (presumably) the engine/SDK teardown calls are missing. TODO restore.
// NOTE(review): as written this compiles only as C++: `auto` type deduction and
// arrays sized by `const int` with initializers are not valid C (before C23).
// malloc/memset also need <stdlib.h>/<string.h>, which are not included above
// (unless nn_sdk.h provides them).
int main(){
// Abort if SDK initialisation fails.
if (0 != sdk_init_cc()) {
return -1;
// The banner below says "see the Python demo for configuration".
printf("配置参考 python.........\n");
// Inline engine config (TensorFlow ckpt model, one input and one output node).
// NOTE(review): uses the older "engine_version"/"data_type" keys; the
// 2021-12-09 changelog renames data_type to dtype.
const char* json_data = "{\n\
\"model_dir\": \"/root/model.ckpt\",\n\
\"log_level\":8, \n\
\"device_id\":0, \n\
\"tf\":{ \n\
\"ConfigProto\": {\n\
\"gpu_options\":{\"allow_growth\": 1}\n\
\"engine_version\": 1,\n\
\"model_type\":1 ,\n\
\"graph\": [\n\
\"input\": [{\"node\":\"input_ids:0\", \"data_type\":\"float\", \"shape\":[1, 10]}],\n\
\"output\" : [{\"node\":\"pred_ids:0\", \"data_type\":\"float\", \"shape\":[1, 10]}]\n\
printf("%s\n", json_data);
// NOTE(review): the handle is not checked before use — confirm what
// sdk_new_cc returns on failure.
auto handle = sdk_new_cc(json_data);
// One input and one output buffer, each M x N = 1 x 10 elements.
const int INPUT_NUM = 1;
const int OUTPUT_NUM = 1;
const int M = 1;
const int N = 10;
// NOTE(review): input buffers are int, but the config declares the input node
// as "float" — confirm which element type the engine expects.
int *input[INPUT_NUM] = { 0 };
float* result[OUTPUT_NUM] = { 0 };
int element_input_size = sizeof(int);
int element_output_size = sizeof(float);
// Allocate zeroed output buffers.
for (int i = 0; i < OUTPUT_NUM; ++i) {
result[i] = (float*)malloc(M * N * element_output_size);
memset(result[i], 0, M * N * element_output_size);
// Allocate input buffers and fill the first row of input i with the value i
// (so the single input here is all zeros).
for(int i =0;i<INPUT_NUM;++i){
input[i] = (int*)malloc(M * N * element_input_size);
memset(input[i], 0, M * N * element_input_size);
for (int j = 0; j < N; ++j) {
input[i][j] = i;
// Run subgraph 0 on one row; 0 is treated as success below.
int batch_size = 1;
int code = sdk_process_cc(handle, 0 , batch_size, (void**)input,(void**)result);
if (code == 0) {
for (int i = 0; i < N; ++i) {
printf("%f ", result[0][i]);
// Cleanup loops; NOTE(review): their bodies (presumably free(input[i]) and
// free(result[i])) are missing from this extract.
for (int i = 0; i < INPUT_NUM; ++i) {
for (int i = 0; i < OUTPUT_NUM; ++i) {
return 0;
# -*- coding: UTF-8 -*-
import sys
from nn_sdk.engine_csdk import sdk_aes_encode_decode
def test_string():
    """Encrypt a short byte string with AES and decrypt it again.

    Uses a fixed demo key and IV (bytes 0..15). Each request passes ``mode``
    explicitly: 0 = encrypt, 1 = decrypt. The original decrypt request omitted
    ``mode``, so it did not ask for decryption. ``sdk_aes_encode_decode``
    returns a ``(code, payload)`` pair, and a code of 0 is treated as success.
    Prints the status code and payload of each step, and stops early if
    encryption fails.
    """
    key = bytes(range(16))
    iv = bytes(range(16))
    plain_in = bytes([1, 2, 3, 5, 255])

    data1 = {
        "mode": 0,  # 0 = encrypt, 1 = decrypt
        "key": key,
        "iv": iv,
        "data": plain_in,
    }
    code, encrypt = sdk_aes_encode_decode(data1)
    print(code, encrypt)
    if code != 0:
        return

    data2 = {
        "mode": 1,  # decrypt the ciphertext produced above
        "key": key,
        "iv": iv,
        "data": encrypt,
    }
    code, plain = sdk_aes_encode_decode(data2)
    print(code, plain, plain == plain_in)
def test_encode_file(in_filename,out_filename):
    """AES-encrypt the contents of ``in_filename`` into ``out_filename``.

    Args:
        in_filename: path of the file to encrypt (read in binary mode).
        out_filename: path the ciphertext is written to (binary mode).

    Returns:
        -1 if the input file is empty; otherwise the status code returned by
        ``sdk_aes_encode_decode``. The output file is written only when that
        code is 0. Previously the file was opened but the ciphertext was never
        written, which left it empty.
    """
    with open(in_filename, mode='rb') as f:
        data = f.read()
    if not data:
        return -1
    request = {
        "mode": 0,  # 0 = encrypt, 1 = decrypt
        "key": bytes(range(16)),
        "iv": bytes(range(16)),
        "data": data,
    }
    code, encrypt = sdk_aes_encode_decode(request)
    if code != 0:
        return code
    with open(out_filename, mode='wb') as f:
        f.write(encrypt)
    return code
def test_decode_file(in_filename,out_filename):
    """AES-decrypt the contents of ``in_filename`` into ``out_filename``.

    Args:
        in_filename: path of the encrypted file (read in binary mode).
        out_filename: path the decrypted bytes are written to (binary mode).

    Returns:
        -1 if the input file is empty; otherwise the status code returned by
        ``sdk_aes_encode_decode``. The output file is written only when that
        code is 0. Previously the file was opened but the plaintext was never
        written, which left it empty.
    """
    with open(in_filename, mode='rb') as f:
        data = f.read()
    if not data:
        return -1
    request = {
        "mode": 1,  # 0 = encrypt, 1 = decrypt
        "key": bytes(range(16)),
        "iv": bytes(range(16)),
        "data": data,
    }
    code, plain = sdk_aes_encode_decode(request)
    if code != 0:
        return code
    with open(out_filename, mode='wb') as f:
        f.write(plain)
    return code