-
Hi, I am looking for ways to load a pre-trained PyTorch model in Java and then continue to fine-tune both the pre-trained model and additional layers (via JIT, or any other mechanism similar to torch.load_state_dict). I found some discussion suggesting that this is possible: pytorch/pytorch#17614

So far, with the javacpp-presets for PyTorch, it seems we can load JIT modules, which means we can run forward passes on a pre-trained model. Is there a way to run backward passes on JIT-compiled TorchScript models as well? I can give the C++ suggestions here a try: pytorch/pytorch#17614 (comment)

BTW, thanks @saudet for suggesting this package in another forum 😄.
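For concreteness, here is a minimal sketch of the forward-plus-backward flow I have in mind, written against the javacpp-presets bindings. The names (`JitModule`, `IValueVector`, `load()`, the `model.pt` path, and the tensor shapes) are my assumptions based on how the presets mirror the libtorch C++ API, so they may differ by preset version:

```java
import org.bytedeco.pytorch.*;
import static org.bytedeco.pytorch.global.torch.*;

public class FineTuneJitSketch {
    public static void main(String[] args) {
        // Assumed mapping of torch::jit::load(): load a TorchScript
        // module exported from Python with torch.jit.trace/script.
        // "model.pt" and the shapes below are placeholders.
        JitModule module = load("model.pt");
        module.train(true); // training mode, as in torch::jit::Module::train()

        // Forward pass: JIT modules take IValues rather than raw Tensors.
        IValueVector inputs = new IValueVector();
        inputs.push_back(new IValue(ones(1, 784)));
        Tensor output = module.forward(inputs).toTensor();

        // Backward pass: autograd records the forward graph, so calling
        // backward() on a scalar loss populates the parameters' gradients.
        Tensor target = zeros(1, 10);
        Tensor loss = mse_loss(output, target);
        loss.backward();
        // An optimizer step over the module's parameters would follow here,
        // assuming the parameters() accessor is exposed by the bindings.
    }
}
```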
-
Looks like we'll need to work on mapping that parameters() function, yes.
-
It looks like it's now possible to train with higher-level functions. BTW, I have already proposed using JavaCPP for PyTorch in DJL. They are already using it for TensorFlow, but TensorFlow's C++ API is quite limited, so I think it would be nice if DJL or anyone else could provide a user-friendly Java API for PyTorch, which should be a lot easier to do with JavaCPP than manually with JNI, as DJL is doing right now.
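Roughly, a training loop with those higher-level functions could look like the sketch below. It builds a small non-JIT network; names like `LinearImpl`, `SGD`, and `SGDOptions` follow the mapped C++ classes, and the batch tensors are synthetic stand-ins, so treat this as illustrative rather than tested:

```java
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import static org.bytedeco.pytorch.global.torch.*;

public class TrainSketch {
    // A small network built from the mapped torch::nn module classes.
    static class Net extends Module {
        final LinearImpl fc1, fc2;
        Net() {
            fc1 = register_module("fc1", new LinearImpl(784, 64));
            fc2 = register_module("fc2", new LinearImpl(64, 10));
        }
        Tensor forward(Tensor x) {
            x = relu(fc1.forward(x));
            return log_softmax(fc2.forward(x), 1);
        }
    }

    public static void main(String[] args) {
        Net net = new Net();
        // SGD over the module's parameters, mirroring torch::optim::SGD.
        SGD optimizer = new SGD(net.parameters(), new SGDOptions(0.01));

        for (int step = 0; step < 100; step++) {
            Tensor input = randn(64, 784);                  // stand-in batch
            Tensor target = randint(0, 10, new long[]{64}); // stand-in labels
            optimizer.zero_grad();
            Tensor loss = nll_loss(net.forward(input), target);
            loss.backward();
            optimizer.step();
        }
    }
}
```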