Skip to content

Commit

Permalink
Merge pull request #5 from viam-labs/model-ext-check
Browse files Browse the repository at this point in the history
check model extension for onnx
  • Loading branch information
bhaney authored Jul 18, 2024
2 parents 14768e4 + 00c5a18 commit d40174c
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 9 deletions.
9 changes: 5 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ jobs:
- name: Install dependencies
run: go mod download

- name: Run unit tests
run: make test

- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.59
args: --disable errcheck --timeout 10m
args: --timeout 10m

- name: Run unit tests
run: make test

4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ MOD_OS := $(shell uname -s)
test:
go test
lint:
golangci-lint run --disable errcheck
golangci-lint run
module.tar.gz:
ifeq ($(MOD_OS),Darwin)
ifeq ($(MOD_ARCH),x86_64)
Expand Down Expand Up @@ -60,4 +60,4 @@ third_party/onnx-android-$(SO_ARCH).so: onnxruntime-android-$(ONNX_VERSION).aar
cp jni/$(SO_ARCH)/libonnxruntime.so $@

bundle-droid-$(SO_ARCH).tar.gz: module third_party/onnx-android-$(SO_ARCH).so
tar -czf $@ $^
tar -czf $@ $^
2 changes: 1 addition & 1 deletion cmd/module/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func realMain() error {
if err != nil {
return err
}
defer onnxClose(ctx)
defer onnxClose(ctx) //nolint:errcheck
<-ctx.Done()
return nil
}
Expand Down
10 changes: 8 additions & 2 deletions onnx_cpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package onnx_cpu

import (
"context"
"path"
"runtime"
"strings"

Expand Down Expand Up @@ -48,9 +49,14 @@ type Config struct {
}

// Validate makes sure that the required model path is not empty
func (cfg *Config) Validate(path string) ([]string, error) {
func (cfg *Config) Validate(validatePath string) ([]string, error) {
if cfg.ModelPath == "" {
return nil, utils.NewConfigValidationFieldRequiredError(path, "model_path")
return nil, utils.NewConfigValidationFieldRequiredError(validatePath, "model_path")
}
ext := path.Ext(cfg.ModelPath)
if ext != ".onnx" {
base := path.Base(cfg.ModelPath)
return nil, errors.Errorf("model_path filename must end in .onnx. The filename is %s", base)
}
return nil, nil
}
Expand Down
19 changes: 19 additions & 0 deletions onnx_cpu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,22 @@ func TestImageDetection(t *testing.T) {
err = theModel.Close(context.Background())
test.That(t, err, test.ShouldBeNil)
}

func TestValidate(t *testing.T) {
// correct
cfg := &Config{"./test_files/ir_mobilenet.onnx", "/path/to/labels.txt"}
deps, err := cfg.Validate("")
test.That(t, err, test.ShouldBeNil)
test.That(t, deps, test.ShouldBeNil)
// empty
cfg = &Config{"", "/path/to/labels.txt"}
_, err = cfg.Validate("")
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err.Error(), test.ShouldContainSubstring, "model_path")
// incorrect
cfg = &Config{"/path/to/other_model.tflite", "/path/to/labels.txt"}
_, err = cfg.Validate("")
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err.Error(), test.ShouldContainSubstring, "must end in .onnx")

}

0 comments on commit d40174c

Please sign in to comment.