发布时间:2024-08-06 12:01
在学习 TensorFlow 的示例应用时,经常可以看到应用依赖一些模型文件,而这些模型的下载过程写在 download.gradle 文件中。
其下载过程如下
// Downloads the PoseNet TFLite model to src/main/assets/posenet.tflite.
task downloadPosenetModel(type: DownloadUrlTask) {
    def posenetUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite"

    sourceUrl = posenetUrl
    target = file("src/main/assets/posenet.tflite")

    // Print the URL before the task's download action runs.
    doFirst {
        println "Downloading ${posenetUrl}"
    }
}
// Downloads the MoveNet single-pose Lightning model to
// src/main/assets/movenet_lightning.tflite.
task downloadMovenetLightningModel(type: DownloadUrlTask) {
    def lightningUrl = "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/float16/4?lite-format=tflite"

    sourceUrl = lightningUrl
    target = file("src/main/assets/movenet_lightning.tflite")

    // Print the URL before the task's download action runs.
    doFirst {
        println "Downloading ${lightningUrl}"
    }
}
// Downloads the MoveNet single-pose Thunder model to
// src/main/assets/movenet_thunder.tflite.
task downloadMovenetThunderModel(type: DownloadUrlTask) {
    def thunderUrl = "https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/float16/4?lite-format=tflite"

    sourceUrl = thunderUrl
    target = file("src/main/assets/movenet_thunder.tflite")

    // Print the URL before the task's download action runs.
    doFirst {
        println "Downloading ${thunderUrl}"
    }
}
// Downloads the MoveNet multi-pose Lightning model to
// src/main/assets/movenet_multipose_fp16.tflite.
task downloadMovenetMultiPoseModel(type: DownloadUrlTask) {
    // Named after this task's model; the previous name was copied from the
    // Thunder task and did not describe the URL it holds.
    def modelMovenetMultiPoseDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/multipose/lightning/tflite/float16/1?lite-format=tflite"

    // Print the URL before the task's download action runs.
    doFirst {
        println "Downloading ${modelMovenetMultiPoseDownloadUrl}"
    }

    sourceUrl = "${modelMovenetMultiPoseDownloadUrl}"
    target = file("src/main/assets/movenet_multipose_fp16.tflite")
}
// Downloads the yoga pose classifier model to src/main/assets/classifier.tflite.
task downloadPoseClassifierModel(type: DownloadUrlTask) {
    def classifierUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/pose_classifier/yoga_classifier.tflite"

    sourceUrl = classifierUrl
    target = file("src/main/assets/classifier.tflite")

    // Print the URL before the task's download action runs.
    doFirst {
        println "Downloading ${classifierUrl}"
    }
}
// Aggregate task: depends on every individual model download task.
task downloadModel {
    dependsOn(
        downloadPosenetModel,
        downloadMovenetLightningModel,
        downloadMovenetThunderModel,
        downloadPoseClassifierModel,
        downloadMovenetMultiPoseModel
    )
}
/**
 * Task type that fetches a single file from a URL.
 *
 * {@code sourceUrl} is declared as a task input and {@code target} as the
 * task's output file.
 */
class DownloadUrlTask extends DefaultTask {

    /** URL the file is fetched from. */
    @Input
    String sourceUrl

    /** File the fetched content is written to. */
    @OutputFile
    File target

    /** Task action: hands the transfer over to Ant's {@code get} task. */
    @TaskAction
    void download() {
        ant.get(dest: target, src: sourceUrl)
    }
}
// Make preBuild depend on the aggregate model download task.
preBuild.dependsOn(downloadModel)
另外,还有一些应用使用 Download 类型的任务来下载 tflite 模型,其下载脚本如下:
// Downloads the MNIST digit classifier model into ASSET_DIR as mnist.tflite.
task downloadModelFile(type: Download) {
    src('https://storage.googleapis.com/download.tensorflow.org/models/tflite/digit_classifier/mnist.tflite')
    dest(project.ext.ASSET_DIR + '/mnist.tflite')
    // Do not overwrite a model file that is already present.
    overwrite(false)
}
// As tasks are added, make both assemble variants depend on the model download.
tasks.whenTaskAdded { addedTask ->
    if (addedTask.name in ['assembleDebug', 'assembleRelease']) {
        addedTask.dependsOn 'downloadModelFile'
    }
}
在国内,往往无法直接从 tfhub.dev 和 googleapis 网站下载这些模型,因此需要先通过外网手动下载,然后将模型文件放到 app/src/main/assets 目录下(即脚本中 target 指向的目录),并按脚本中指定的目标文件名重命名(例如 posenet.tflite、movenet_lightning.tflite)。
通常建议将第一种写法改为上面的第二种写法:第二种写法设置了 overwrite false,目标文件已存在时不会重新下载,因此手动放入的模型文件可以直接使用。
如果仍然不行,将第一种写法对应的 download.gradle 文件中的内容清空即可。