File size: 7,712 Bytes
eb310cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
package com.example.open.diffusion;
import androidx.appcompat.app.AppCompatActivity;
import androidx.appcompat.widget.AppCompatSpinner;
import android.app.ProgressDialog;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.text.TextUtils;
import android.view.View;
import android.widget.EditText;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import com.example.open.diffusion.core.UNet;
import com.example.open.diffusion.core.tokenizer.EngTokenizer;
import com.example.open.diffusion.core.tokenizer.TextTokenizer;
import java.io.File;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import ai.onnxruntime.OnnxTensor;
public class MainActivity extends AppCompatActivity {
private final ExecutorService exec = Executors.newCachedThreadPool();
private final int[] resolution = {192, 256, 320, 384, 448, 512};
private ImageView mImageView;
private TextView mMsgView;
private EditText mGuidanceView;
private EditText mStepView;
private EditText mPromptView;
private EditText mNetPromptView;
private AppCompatSpinner mWidthSpinner;
private AppCompatSpinner mHeightSpinner;
private ProgressDialog progressDialog;
private EditText mSeedView;
private UNet uNet;
private TextTokenizer tokenizer;
private boolean isWorking = false;
@Override
protected void onDestroy() {
super.onDestroy();
try {
uNet.close();
tokenizer.close();
}catch (Exception e){
e.printStackTrace();
}
}
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
mImageView = findViewById(R.id.image);
mMsgView = findViewById(R.id.msg);
mGuidanceView = findViewById(R.id.guidance);
mStepView = findViewById(R.id.step);
mPromptView = findViewById(R.id.prompt);
mWidthSpinner = findViewById(R.id.width);
mHeightSpinner = findViewById(R.id.height);
mNetPromptView = findViewById(R.id.neg_prompt);
mSeedView = findViewById(R.id.seed);
mWidthSpinner.setSelection(3);
mHeightSpinner.setSelection(3);
progressDialog = new ProgressDialog(MainActivity.this);
uNet = new UNet(this, Device.CPU);
tokenizer = new EngTokenizer(this);
uNet.setCallback(new UNet.Callback() {
@Override
public void onStep(int maxStep, int step) {
runOnUiThread(new MyRunnable() {
@Override
public void run() {
mMsgView.setText(String.format("%d / %d", step + 1, maxStep));
}
});
}
@Override
public void onBuildImage(int status, Bitmap bitmap) {
runOnUiThread(new MyRunnable() {
@Override
public void run() {
if (bitmap != null) mImageView.setImageBitmap(bitmap);
}
});
}
@Override
public void onComplete() {
runOnUiThread(new MyRunnable() {
@Override
public void run() {
mMsgView.setText("已完成");
}
});
}
@Override
public void onStop() {
}
});
findViewById(R.id.copy).setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
progressDialog.show();
exec.execute(new MyRunnable() {
@Override
public void run() {
try {
FileUtils.copyAssets(getAssets(), "model", new File(PathManager.getAsssetOutputPath(MainActivity.this)));
}catch (Exception e){
e.printStackTrace();
}finally {
runOnUiThread(new Runnable() {
@Override
public void run() {
progressDialog.dismiss();
}
});
}
}
});
}
});
findViewById(R.id.generate).setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
try {
if (isWorking) return;
isWorking = true;
mMsgView.setText("初始化. . .");
exec.execute(createRunnable());
}catch (Exception e){
e.printStackTrace();
}
}
});
}
private MyRunnable createRunnable(){
final String guidanceText = mGuidanceView.getText().toString();
final String stepText = mStepView.getText().toString();
final String prompt = mPromptView.getText().toString();
final String negPrompt = mNetPromptView.getText().toString();
final String seedText = mSeedView.getText().toString();
final int num_inference_steps = TextUtils.isEmpty(stepText) ? 8 : Integer.parseInt(stepText);
final double guidance_scale = TextUtils.isEmpty(guidanceText) ? 7.5f : Float.valueOf(guidanceText);
final long seed = TextUtils.isEmpty(seedText) ? 0 : Long.parseLong(seedText);
UNet.WIDTH = resolution[mWidthSpinner.getSelectedItemPosition()];
UNet.HEIGHT = resolution[mHeightSpinner.getSelectedItemPosition()];
return new MyRunnable() {
@Override
public void run() {
try {
tokenizer.init();
int batch_size = 1;
int[] textTokenized = tokenizer.encoder(prompt);
int[] negTokenized = tokenizer.createUncondInput(negPrompt);
OnnxTensor textPromptEmbeddings = tokenizer.tensor(textTokenized);
OnnxTensor uncondEmbedding = tokenizer.tensor(negTokenized);
float[][][] textEmbeddingArray = new float[2][tokenizer.getMaxLength()][768];
float[] textPromptEmbeddingArray = textPromptEmbeddings.getFloatBuffer().array();
float[] uncondEmbeddingArray = uncondEmbedding.getFloatBuffer().array();
for (int i = 0; i < textPromptEmbeddingArray.length; i++)
{
textEmbeddingArray[0][i / 768][i % 768] = uncondEmbeddingArray[i];
textEmbeddingArray[1][i / 768][i % 768] = textPromptEmbeddingArray[i];
}
OnnxTensor textEmbeddings = OnnxTensor.createTensor(App.ENVIRONMENT, textEmbeddingArray);
tokenizer.close();
uNet.init();
uNet.inference(seed, num_inference_steps, textEmbeddings, guidance_scale, batch_size, UNet.WIDTH, UNet.HEIGHT);
}catch (Exception e){
runOnUiThread(new Runnable() {
@Override
public void run() {
mMsgView.setText("Error");
}
});
e.printStackTrace();
}finally {
isWorking = false;
}
}
};
}
} |