package torch
Sources
sha256=ccd9ef3b630bdc7c41e363e71d8ecb86c316460cbf79afe67546c6ff22c19da4
Description
The ocaml-torch project provides some OCaml bindings for the Torch library. This brings to OCaml NumPy-like tensor computations with GPU acceleration and tape-based automatic differentiation.
Published: 14 Jun 2023
README
ocaml-torch
ocaml-torch provides some OCaml bindings for the PyTorch tensor library. This brings to OCaml NumPy-like tensor computations with GPU acceleration and tape-based automatic differentiation.
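As a rough illustration of what tape-based automatic differentiation means in practice, the sketch below records a computation on a tensor created with `~requires_grad:true`, then calls `Tensor.backward` to fill in its gradient. It assumes the `torch` opam package and the `Tensor.ones`, `Tensor.sum`, `Tensor.grad` and `Tensor.to_float1_exn` functions; exact signatures may vary between releases.

```ocaml
open Torch

(* [x] is a leaf tensor: operations applied to it are recorded on the tape. *)
let x = Tensor.ones [ 3 ] ~requires_grad:true

(* y = sum_i x_i * x_i, so dy/dx_i = 2 * x_i, i.e. 2.0 for every element here. *)
let y = Tensor.(sum (x * x))

(* Replaying the tape backwards populates the gradient stored on [x]. *)
let () = Tensor.backward y

(* Copy the gradient out as a plain OCaml float array. *)
let grads = Tensor.to_float1_exn (Tensor.grad x)

let () = Array.iter (fun g -> Printf.printf "%.1f\n" g) grads
```

Since every element of `x` is 1, each printed gradient should be 2.0.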
These bindings use the PyTorch C++ API and are mostly automatically generated. The current GitHub tip and the opam package v0.7 correspond to PyTorch v1.13.0.
On Linux, note that you will need the PyTorch build that uses the appropriate cxx11 ABI for your g++ version (CPU version, CUDA 11.6 version).
Opam Installation
The opam package can be installed using the following command. This automatically installs the CPU version of libtorch.
```
opam install torch
```
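Because the opam package installs the CPU build of libtorch by default, it can help to check which device the installation actually exposes. This is a minimal sketch, assuming `Torch.Cuda.is_available`, `Torch.Cuda.device_count` and `Torch.Device.cuda_if_available` behave as in the torch package; with the default CPU build, `is_available` is expected to return `false`.

```ocaml
open Torch

(* Whether the linked libtorch can see a CUDA device, and how many. *)
let cuda_available = Cuda.is_available ()
let gpu_count = Cuda.device_count ()

(* Picks the first GPU when one is usable, otherwise falls back to the CPU. *)
let device = Device.cuda_if_available ()

let () =
  Printf.printf "CUDA available: %b, devices: %d\n" cuda_available gpu_count;
  (* Allocate a small tensor directly on the selected device. *)
  let t = Tensor.zeros [ 2; 2 ] ~device in
  Tensor.print t
```

This can be built with the same `(libraries torch)` dune stanza used for the first example.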
You can then compile some sample code; see the instructions below. ocaml-torch can also be used in interactive mode via utop or ocaml-jupyter.
Here is a sample utop session.
Build a Simple Example
To build a first torch program, create a file `example.ml` with the following content.
```ocaml
open Torch

let () =
  let tensor = Tensor.randn [ 4; 2 ] in
  Tensor.print tensor
```
Then create a `dune` file with the following content:
```
(executables
 (names example)
 (libraries torch))
```
Run `dune exec example.exe` to compile the program and run it!

Alternatively, you can first compile the code via `dune build example.exe`, then run the executable `_build/default/example.exe` (note that building the bytecode target `example.bc` may not work on macOS).
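The first program only prints a random tensor. A slightly larger variant in the same setup, sketched here under the assumption that `Tensor.mm`, `Tensor.ones` and `Tensor.shape` are available as in the torch package, multiplies two random matrices and relies on broadcasting to add a row vector to every row of the result.

```ocaml
open Torch

(* A [4; 2] matrix times a [2; 3] matrix gives a [4; 3] matrix. *)
let a = Tensor.randn [ 4; 2 ]
let b = Tensor.randn [ 2; 3 ]
let c = Tensor.mm a b

(* Elementwise addition broadcasts the [3] vector across all 4 rows of [c]. *)
let bias = Tensor.ones [ 3 ]
let d = Tensor.(c + bias)

let () =
  Printf.printf "shape of c: [%s]\n"
    (String.concat "; " (List.map string_of_int (Tensor.shape c)));
  Tensor.print d
```

Replacing the contents of `example.ml` with this and running `dune exec example.exe` should work with the same `dune` file.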
Tutorials and Examples
Some more advanced applications from external repos:
- An OCaml port of mini-dalle by Arulselvan Madhavan.
- Natural Language Processing models based on BERT can be found in the ocaml-bert repo.
Sample Code
Below is an example of a linear model trained on the MNIST dataset (full code).
```ocaml
(* Assumes [image_dim], [label_count], [learning_rate], [test_samples] and the
   MNIST train/test tensors are bound as in the full code. *)
(* Create two tensors to store model weights. *)
let ws = Tensor.zeros [ image_dim; label_count ] ~requires_grad:true in
let bs = Tensor.zeros [ label_count ] ~requires_grad:true in
let model xs = Tensor.(mm xs ws + bs) in
for index = 1 to 100 do
  (* Compute the cross-entropy loss. *)
  let loss =
    Tensor.cross_entropy_for_logits (model train_images) ~targets:train_labels
  in
  Tensor.backward loss;
  (* Apply gradient descent, disable gradient tracking for these. *)
  Tensor.(
    no_grad (fun () ->
        ws -= grad ws * f learning_rate;
        bs -= grad bs * f learning_rate));
  (* Gradients accumulate across backward calls, so reset them. *)
  Tensor.zero_grad ws;
  Tensor.zero_grad bs;
  (* Compute the validation accuracy. *)
  let test_accuracy =
    Tensor.(argmax ~dim:(-1) (model test_images) = test_labels)
    |> Tensor.to_kind ~kind:(T Float)
    |> Tensor.sum
    |> Tensor.float_value
    |> fun sum -> sum /. test_samples
  in
  Stdio.printf "%d %f %.2f%%\n%!" index (Tensor.float_value loss) (100. *. test_accuracy)
done
```
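The MNIST snippet depends on the dataset and helper bindings from the full code. To try the same training pattern (weights with `~requires_grad:true`, `Tensor.backward`, an update inside `Tensor.no_grad`, then `Tensor.zero_grad`) without downloading anything, here is a self-contained sketch that fits a hypothetical noise-free line y = 2x + 1 by gradient descent on a mean squared error. The toy data, step count and learning rate are illustrative choices, not part of the original example.

```ocaml
open Torch

(* Hypothetical toy data: 100 points on the line y = 2x + 1, no noise. *)
let n = 100
let xs = Tensor.randn [ n; 1 ]
let ys = Tensor.((xs * f 2.) + f 1.)

(* Same parameter layout as the MNIST snippet: a weight matrix and a bias. *)
let ws = Tensor.zeros [ 1; 1 ] ~requires_grad:true
let bs = Tensor.zeros [ 1 ] ~requires_grad:true
let model xs = Tensor.(mm xs ws + bs)
let learning_rate = 0.1

let () =
  for _step = 1 to 200 do
    (* Mean squared error built from basic tensor operations. *)
    let diff = Tensor.(model xs - ys) in
    let loss = Tensor.(sum (diff * diff) / f (float_of_int n)) in
    Tensor.backward loss;
    (* Plain gradient descent step, not tracked by autograd. *)
    Tensor.(
      no_grad (fun () ->
          ws -= grad ws * f learning_rate;
          bs -= grad bs * f learning_rate));
    (* Gradients accumulate across backward calls, so reset them. *)
    Tensor.zero_grad ws;
    Tensor.zero_grad bs
  done

(* Read the learned parameters back as OCaml floats. *)
let w = (Tensor.to_float2_exn ws).(0).(0)
let b = (Tensor.to_float1_exn bs).(0)
let () = Printf.printf "w = %.3f, b = %.3f\n" w b
```

With standard-normal inputs the loss is well conditioned, so after 200 steps the learned `w` and `b` should be very close to 2 and 1.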
A simplified version of char-rnn illustrating character-level language modeling using Recurrent Neural Networks.
Neural Style Transfer applies the style of an image to the content of another image. This uses a deep Convolutional Neural Network.
Models and Weights
Various pretrained computer vision models are implemented in the vision library. The weight files can be downloaded at the following links:
- ResNet18 weights.
- ResNet34 weights.
- ResNet50 weights.
- ResNet101 weights.
- ResNet152 weights.
- DenseNet121 weights.
- DenseNet161 weights.
- DenseNet169 weights.
- SqueezeNet 1.0 weights.
- SqueezeNet 1.1 weights.
- VGG13 weights.
- VGG16 weights.
- AlexNet weights.
- Inception-v3 weights.
- MobileNet-v2 weights.
- EfficientNet b0 weights, b1 weights, b2 weights, b3 weights, b4 weights.
Running the pre-trained models on some sample images can then easily be done via the following command:
```
dune exec examples/pretrained/predict.exe path/to/resnet18.ot images/tiger.jpg
```
Acknowledgements
Many thanks to @LaurentMazare for the original work on ocaml-torch.
Dependencies (14)

| Package | Version constraint |
| --- | --- |
| libtorch | >= "1.13.0" & < "1.14.0" |
| ocaml-compiler-libs | >= "v0.11.0" |
| dune-configurator | no constraint |
| dune | >= "2.0.0" |
| ctypes-foreign | no constraint |
| ctypes | >= "0.18.0" |
| stdio | >= "v0.16" & < "v0.17" |
| ppx_jane | >= "v0.16" & < "v0.17" |
| ppx_inline_test | >= "v0.16" & < "v0.17" |
| ppx_expect | >= "v0.16" & < "v0.17" |
| ppx_bench | >= "v0.16" & < "v0.17" |
| core | >= "v0.16" & < "v0.17" |
| base | >= "v0.16" & < "v0.17" |
| ocaml | >= "4.14" |
Dev Dependencies
None
Used by
None
Conflicts
None