```python
# Build the model
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Dense(torch.nn.Module):
    def __init__(self):
        super(Dense, self).__init__()
        self.linear = torch.nn.Linear(768, 64, bias=False)
        self.activation_function = torch.nn.Identity()

    def forward(self, input):
        output = self.linear(input)
        output = self.activation_function(output)
        return output

model = Dense().to(device)
model.load_state_dict(torch.load("pytorch_model.bin"))
model.eval()

# Sanity check: one forward pass before exporting
with torch.no_grad():
    result = model(torch.tensor([[0.0] * 768]).to(device))

output_path = './modelDenseBatch.onnx'  # where to write the ONNX model
dummy_input = torch.tensor([[0.0] * 768]).to(device)
torch.onnx.export(model,               # model being run
                  args=dummy_input,    # model input (or a tuple for multiple inputs)
                  f=output_path,
                  input_names=['ids'],
                  opset_version=11,
                  dynamic_axes={'ids': {0: 'batch_size', 1: 'token_size'}})  # dynamic batch
```
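After exporting, it is worth confirming that the graph is well formed and that the dynamic batch axis really accepts different batch sizes. A minimal sketch, assuming the export above succeeded and the `onnx` and `onnxruntime` packages are installed:

```python
import onnx
import onnxruntime
import numpy as np

# Structural check of the exported graph
onnx.checker.check_model(onnx.load('./modelDenseBatch.onnx'))

# The 'batch_size' axis was declared dynamic, so any batch size should work
sess = onnxruntime.InferenceSession('./modelDenseBatch.onnx', providers=['CPUExecutionProvider'])
for batch in (1, 4, 16):
    out = sess.run(None, {'ids': np.zeros((batch, 768), dtype=np.float32)})
    print(batch, out[0].shape)  # expect (batch, 64)
```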
```python
# For a model with several outputs, name each output and mark its batch axis as dynamic
ONNXMODEL_PATH = './modelMultiOutput.onnx'  # placeholder path for the exported model
X = torch.tensor([[0.0] * 6534]).to(device)
torch.onnx.export(model, X,
                  ONNXMODEL_PATH,
                  input_names=['input'],
                  output_names=['output1', 'output2', 'output3'],
                  opset_version=10,
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output1': {0: 'batch_size'},
                                'output2': {0: 'batch_size'},
                                'output3': {0: 'batch_size'}})
```
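On the inference side, `run()` returns the outputs in the order they were named at export time, and each one keeps its own dynamic batch axis. A minimal sketch, assuming a model with the three outputs above was exported to ONNXMODEL_PATH:

```python
import onnxruntime
import numpy as np

sess = onnxruntime.InferenceSession(ONNXMODEL_PATH, providers=['CPUExecutionProvider'])
# Request the three outputs by name; a batch of 2 exercises the dynamic axis
output1, output2, output3 = sess.run(
    ['output1', 'output2', 'output3'],
    {'input': np.zeros((2, 6534), dtype=np.float32)})
print(output1.shape, output2.shape, output3.shape)
```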
```python
import torch
import onnxruntime
import numpy as np

# Check whether the GPU is being used
print(onnxruntime.get_device())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = onnxruntime.InferenceSession("./../modelDenseBatch.onnx", providers=['CUDAExecutionProvider'])

def to_numpy(tensor):
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()

ort_input = torch.tensor([[0.0] * 768]).to(device)
outputs = MODEL.run(None, {"ids": to_numpy(ort_input)})
```
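A common sanity check at this point is to run the original PyTorch model on the same input and compare it numerically with the ONNX Runtime result. A minimal sketch, reusing `model`, `ort_input`, `to_numpy`, and `outputs` from above:

```python
with torch.no_grad():
    torch_out = model(ort_input)

# Agreement within float32 tolerance confirms the export is faithful
np.testing.assert_allclose(to_numpy(torch_out), outputs[0], rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX Runtime outputs match")
```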
```xml
<!-- https://mvnrepository.com/artifact/org.bytedeco/onnxruntime-platform -->
<dependency>
    <groupId>org.bytedeco</groupId>
    <artifactId>onnxruntime-platform</artifactId>
    <version>1.8.1-1.5.6</version>
</dependency>
```
```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class MainCLI {
    public static void main(String[] args) {
        String modelFile = "src/main/resources/model.onnx";
        System.out.println("Loading model from " + modelFile);
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try {
            OrtSession session = env.createSession(modelFile, new OrtSession.SessionOptions());
            // Input: a 1x103 float tensor, fed under the model's input name
            float[][] inputArr = new float[1][103];
            OnnxTensor t1 = OnnxTensor.createTensor(env, inputArr);
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input.1", t1);
            // Run inference
            OrtSession.Result results = session.run(inputs);
            System.out.println("output (" + results.size() + "): " + results.get(0).getInfo());
            float[][] labels = (float[][]) results.get(0).getValue();
            System.out.println("output value: " + Arrays.toString(labels[0]));
        } catch (OrtException e) {
            e.printStackTrace();
        } finally {
            // Close the environment whether or not inference succeeded
            try {
                env.close();
            } catch (OrtException oe) {
                oe.printStackTrace();
            }
        }
    }
}
```
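The "input.1" key has to match the input name baked into the graph. If that name is unknown, it can be listed from Python before wiring up the Java map. A minimal sketch, assuming the same model.onnx:

```python
import onnxruntime

sess = onnxruntime.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])
print([i.name for i in sess.get_inputs()])   # e.g. ['input.1']
print([o.name for o in sess.get_outputs()])
```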