PyTorch-Mobile-Android (3): Deploying Your Own Model
I. Example
1. Convert to TorchScript with torch.jit.script, not torch.jit.trace
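For the reasons, see [PyTorch Deployment] TorchScript - Zhihu (zhihu.com): https://zhuanlan.zhihu.com/p/135911580
In short, torch.jit.trace only records the operations executed for the example input, so data-dependent control flow (if statements, loops) is frozen into the graph. torch.jit.script compiles the model's Python code and keeps that control flow.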
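import vision_transformer
from torch.utils.mobile_optimizer import optimize_for_mobile
import torch
model_vit = vision_transformer._create_vision_transformer('vit_tiny_patch16_384')
model_vit = model_vit.eval()
# torch.jit.script compiles the model's code, so no example input is passed to it
scripted_module = torch.jit.script(model_vit)
# optional sanity check on a dummy 384x384 input before exporting
example = torch.rand(1, 3, 384, 384)
print(scripted_module(example).shape)
scripted_module_optimized = optimize_for_mobile(scripted_module)
scripted_module_optimized._save_for_lite_interpreter(r"D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\vit2.pt")
Note: if the example is passed as a second argument, i.e. torch.jit.script(model_vit, example), you get the warning
UserWarning: `optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution()` instead
This happens because the second positional parameter of torch.jit.script is the deprecated optimize flag, not an example input, so the tensor is simply ignored. The warning does not affect the exported model, and calling torch.jit.script(model_vit) as above avoids it.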
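The reason given in step 1 can be checked on a small toy module. The sketch below uses a hypothetical `Gate` module (not part of the ViT) with a data-dependent branch, and assumes a desktop PyTorch install with TorchScript. Tracing with a positive input records only the `x * 2` path, while scripting keeps the `if`. Tracing will also print a TracerWarning about converting a tensor to a Python boolean.

```python
import torch


class Gate(torch.nn.Module):
    """Hypothetical toy module whose output path depends on the input values."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        else:
            return -x


model = Gate().eval()
positive = torch.ones(2)
negative = -torch.ones(2)

# script: compiles the Python source, so the if/else survives
scripted = torch.jit.script(model)
# trace: runs the model once on `positive` and records only the ops it executed
traced = torch.jit.trace(model, positive)

print(scripted(negative))  # takes the else branch: -x
print(traced(negative))    # replays the recorded x * 2 path regardless of the input
```

For a model like ViT whose forward pass has no data-dependent branches, both approaches may give the same graph, but scripting avoids silently baking in one path.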
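2. Resize the image with PIL so its width and height match the model input (384×384 for vit_tiny_patch16_384; the Android code below feeds the bitmap to the model without resizing it)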
from PIL import Image
img = Image.open(r'D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\image.jpg')
# resize to the model's 384x384 input and overwrite the asset (only needs to run once)
img = img.resize((384, 384), Image.BILINEAR)
img.save(r'D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\image.jpg')
print(img.size)  # should print (384, 384)
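Resizing straight to 384×384 stretches non-square images. A common alternative, borrowed from torchvision-style preprocessing and not what this post does, is to scale the shorter side to 384 and then center-crop. Below is a minimal sketch of that geometry; `resize_then_center_crop_box` is a hypothetical helper name, and the crop box follows PIL's (left, upper, right, lower) convention.

```python
def resize_then_center_crop_box(width, height, target=384):
    """Hypothetical helper: scale the shorter side to `target`, then center-crop.

    Returns (new_width, new_height, crop_box), where crop_box is
    (left, upper, right, lower) in the resized image.
    """
    scale = target / min(width, height)
    # max() guards against rounding the shorter side just below target
    new_w = max(target, round(width * scale))
    new_h = max(target, round(height * scale))
    left = (new_w - target) // 2
    top = (new_h - target) // 2
    return new_w, new_h, (left, top, left + target, top + target)


print(resize_then_center_crop_box(640, 480))
```

With PIL this would be used as `new_w, new_h, box = resize_then_center_crop_box(*img.size)` followed by `img = img.resize((new_w, new_h), Image.BILINEAR).crop(box)`.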
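3. Adapt the MainActivity code from the official PyTorch Mobile HelloWorld demo and run it (here the model asset loaded is vit2.pt)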
package org.pytorch.helloworld;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        Bitmap bitmap = null;
        Module module = null;
        try {
            // creating bitmap from the packaged app asset 'image.jpg'
            // (app/src/main/assets/image.jpg, already resized to 384x384 in step 2)
            bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
            // log the input size to confirm it matches the model input
            Log.d("PytorchHelloWorld", String.format("width %d, height %d", bitmap.getWidth(), bitmap.getHeight()));
            // loading the serialized lite-interpreter module from the packaged app asset
            // (app/src/main/assets/vit2.pt)
            module = LiteModuleLoader.load(assetFilePath(this, "vit2.pt"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
            return; // bitmap/module may be null, so do not continue
        }
        // showing image on UI
        ImageView imageView = findViewById(R.id.image);
        imageView.setImageBitmap(bitmap);
        // preparing input tensor: channels scaled to [0, 1], then normalized with the ImageNet mean/std
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB,
                MemoryFormat.CHANNELS_LAST);
        // running the model
        final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
        // getting tensor content as java array of floats
        final float[] scores = outputTensor.getDataAsFloatArray();
        // searching for the index with maximum score
        float maxScore = -Float.MAX_VALUE;
        int maxScoreIdx = -1;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > maxScore) {
                maxScore = scores[i];
                maxScoreIdx = i;
            }
        }
        String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
        // showing the predicted class on UI
        TextView textView = findViewById(R.id.text);
        textView.setText(className);
    }
    /**
     * Copies specified asset to the file in /files app directory and returns this file absolute path.
     * An existing non-empty copy is reused, so after replacing vit2.pt in assets,
     * reinstall the app or clear its data to pick up the new model.
     *
     * @return absolute file path
     */
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }
        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
}
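When the app shows an unexpected class, it can help to reproduce the Android pre- and post-processing on the desktop. The sketch below mirrors, in plain Python, what the Java code does per pixel and on the output scores, assuming TensorImageUtils.bitmapToFloat32Tensor computes (channel / 255 - mean) / std with the ImageNet statistics behind TORCHVISION_NORM_MEAN_RGB and TORCHVISION_NORM_STD_RGB. The helper names `normalize_pixel` and `argmax` are hypothetical.

```python
# Values of TensorImageUtils.TORCHVISION_NORM_MEAN_RGB / _STD_RGB (ImageNet statistics)
TORCHVISION_NORM_MEAN_RGB = (0.485, 0.456, 0.406)
TORCHVISION_NORM_STD_RGB = (0.229, 0.224, 0.225)


def normalize_pixel(r, g, b):
    """Scale 0-255 channels to [0, 1], then standardize with the ImageNet mean/std."""
    return tuple(
        (c / 255.0 - mean) / std
        for c, mean, std in zip((r, g, b), TORCHVISION_NORM_MEAN_RGB, TORCHVISION_NORM_STD_RGB)
    )


def argmax(scores):
    """Index of the largest score, keeping the first index on ties like the Java loop."""
    max_score, max_idx = float("-inf"), -1
    for i, score in enumerate(scores):
        if score > max_score:
            max_score, max_idx = score, i
    return max_idx


print(normalize_pixel(255, 255, 255))
print(argmax([0.1, 2.5, -1.0, 2.5]))
```

Comparing the index from `argmax` on the desktop model's output with `maxScoreIdx` in the app is a quick way to tell a preprocessing mismatch from an export problem.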
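4. Updating the Lite runtime version
As PyTorch is updated, the dependency versions in the official demo's build.gradle (org.pytorch:pytorch_android_lite and org.pytorch:pytorch_android_torchvision_lite) become too old, and loading some models then fails with errors. In that case, open build.gradle and hover the mouse over the packages under dependencies; Android Studio shows a prompt suggesting a newer version, and applying the update is enough. As a rule of thumb, the runtime should not be older than the PyTorch version used to export the model.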