-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #198 from Idein/add-pow-operator
Add pow operator
- Loading branch information
Showing
5 changed files
with
123 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
all: super_resolution.nnoir | ||
|
||
super_resolution.onnx: | ||
wget -O $@ https://github.com/onnx/models/raw/main/archive/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx | ||
wget -O $@ https://github.com/onnx/models/raw/main/validated/vision/super_resolution/sub_pixel_cnn_2016/model/super-resolution-10.onnx | ||
|
||
super_resolution.nnoir: super_resolution.onnx | ||
onnx2nnoir -o $@ --graph_name torch_jit_export --fix_dimension batch_size=1 $< |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import numpy as np | ||
import onnx | ||
from nnoir.functions import Function, Mul | ||
from numpy.typing import NDArray | ||
|
||
from .utils import * | ||
|
||
|
||
class OpPow(Op): | ||
def __init__(self, node: onnx.NodeProto, *args: Any): | ||
super(OpPow, self).__init__(node, *args) | ||
|
||
def to_function(self, env: Dict[str, NDArray[Any]], constants: Dict[str, NDArray[Any]]) -> List[Function]: | ||
[a, b] = self.node.input | ||
|
||
if b in constants and constants[b] == 2.0: | ||
return [Mul([a, a], list(self.node.output))] | ||
else: | ||
raise UnsupportedONNXOperation(self.node, "unimplemented yet") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from typing import Any, Dict, List | ||
|
||
import numpy as np | ||
import onnx | ||
import pytest | ||
from numpy.typing import NDArray | ||
from onnx import TensorProto | ||
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info | ||
from onnx.numpy_helper import from_array | ||
from util import Base | ||
|
||
info = make_tensor_value_info | ||
|
||
|
||
def test_pow_00() -> None: | ||
""" | ||
Test for y is constant(=2) | ||
""" | ||
|
||
class PowTester(Base): | ||
def __init__(self, inputs: Dict[str, NDArray[Any]], outputs: List[str]): | ||
super().__init__(inputs, outputs) | ||
|
||
def create_onnx(self) -> onnx.ModelProto: | ||
node = make_node("Pow", inputs=["x", "y"], outputs=["z"]) | ||
inputs = [ | ||
info("x", TensorProto.FLOAT, (1, 3, 4, 5)), | ||
] | ||
outputs = [info("z", TensorProto.FLOAT, (1, 3, 4, 5))] | ||
|
||
# constant which dimension is none | ||
const = from_array(np.array(2.0).astype(np.float32), "y") | ||
graph = make_graph([node], "pow_graph", inputs, outputs, initializer=[const]) | ||
model = make_model(graph) | ||
return model | ||
|
||
x = np.random.rand(1, 3, 4, 5).astype(np.float32) | ||
outputs = ["z"] | ||
PowTester({"x": x}, outputs).run() | ||
|
||
|
||
@pytest.mark.xfail() | ||
def test_pow_01() -> None: | ||
""" | ||
Test for y is constant(!=2) | ||
Not inplemented yet | ||
""" | ||
|
||
class PowTester(Base): | ||
def __init__(self, inputs: Dict[str, NDArray[Any]], outputs: List[str]): | ||
super().__init__(inputs, outputs) | ||
|
||
def create_onnx(self) -> onnx.ModelProto: | ||
node = make_node("Pow", inputs=["x", "y"], outputs=["z"]) | ||
inputs = [ | ||
info("x", TensorProto.FLOAT, (1, 3, 4, 5)), | ||
] | ||
outputs = [info("z", TensorProto.FLOAT, (1, 3, 4, 5))] | ||
|
||
# constant which dimension is none | ||
const = from_array(np.array(3.0).astype(np.float32), "y") | ||
graph = make_graph([node], "pow_graph", inputs, outputs, initializer=[const]) | ||
model = make_model(graph) | ||
return model | ||
|
||
x = np.random.rand(1, 3, 4, 5).astype(np.float32) | ||
outputs = ["z"] | ||
PowTester({"x": x}, outputs).run() | ||
|
||
|
||
@pytest.mark.xfail() | ||
def test_pow_02() -> None: | ||
""" | ||
Test for y is not constant | ||
Not inplemented yet | ||
""" | ||
shape = (3, 4, 5) | ||
|
||
class PowTester(Base): | ||
def __init__(self, inputs: Dict[str, NDArray[Any]], outputs: List[str]): | ||
super().__init__(inputs, outputs) | ||
|
||
def create_onnx(self) -> onnx.ModelProto: | ||
node = make_node("Pow", inputs=["x", "y"], outputs=["z"]) | ||
inputs = [ | ||
info("x", TensorProto.FLOAT, shape), | ||
info("y", TensorProto.FLOAT, shape), | ||
] | ||
outputs = [info("z", TensorProto.FLOAT, shape)] | ||
graph = make_graph([node], "pow_graph", inputs, outputs) | ||
model = make_model(graph) | ||
return model | ||
|
||
v0 = np.random.rand(*shape).astype(np.float32) | ||
v1 = np.random.rand(*shape).astype(np.float32) | ||
|
||
outputs = ["z"] | ||
PowTester({"x": v0, "y": v1}, outputs).run() |