package neural_nets_lib

  1. Overview
  2. Docs

The row type, shape-inference-related types, and constraint solving.

type kind = [
  1. | `Batch
  2. | `Input
  3. | `Output
]
val equal_kind : kind -> kind -> Base.bool
val compare_kind : kind -> kind -> Base.int
val sexp_of_kind : kind -> Sexplib0.Sexp.t
val kind_of_sexp : Sexplib0.Sexp.t -> kind
val __kind_of_sexp__ : Sexplib0.Sexp.t -> kind
val hash_fold_kind : Base.Hash.state -> kind -> Base.Hash.state
val hash_kind : kind -> Base.Hash.hash_value
val batch : kind
val input : kind
val output : kind
val is_batch : kind -> Base.bool
val is_input : kind -> Base.bool
val is_output : kind -> Base.bool
val batch_val : kind -> Base.unit Base.option
val input_val : kind -> Base.unit Base.option
val output_val : kind -> Base.unit Base.option
module Variants_of_kind : sig ... end
type dim_var
val equal_dim_var : dim_var -> dim_var -> Base.bool
val hash_fold_dim_var : Base.Hash.state -> dim_var -> Base.Hash.state
val hash_dim_var : dim_var -> Base.Hash.hash_value
val compare_dim_var : dim_var -> dim_var -> Base.int
val sexp_of_dim_var : dim_var -> Sexplib0.Sexp.t
val dim_var_of_sexp : Sexplib0.Sexp.t -> dim_var
type dim_cmp
type dim_var_set = (dim_var, dim_cmp) Base.Set.t
val equal_dim_var_set : dim_var_set -> dim_var_set -> Base.bool
val sexp_of_dim_var_set : dim_var_set -> Sexplib0.Sexp.t
val dim_var_set_of_sexp : Sexplib0.Sexp.t -> dim_var_set
type 'a dim_map = (dim_var, 'a, dim_cmp) Base.Map.t
val equal_dim_map : ('a -> 'a -> Base.bool) -> 'a dim_map -> 'a dim_map -> Base.bool
val sexp_of_dim_map : ('a -> Sexplib0.Sexp.t) -> 'a dim_map -> Sexplib0.Sexp.t
val dim_map_of_sexp : (Sexplib0.Sexp.t -> 'a) -> Sexplib0.Sexp.t -> 'a dim_map
val get_var : ?label:Base.string -> Base.unit -> dim_var
val dim_var_set_empty : dim_var_set
val dim_map_empty : 'a dim_map
type dim =
  1. | Var of dim_var
  2. | Dim of {
    1. d : Base.int;
    2. label : Base.string Base.option;
    3. proj_id : Base.int Base.option;
    }

A single axis in a shape.

val equal_dim : dim -> dim -> Base.bool
val hash_fold_dim : Base.Hash.state -> dim -> Base.Hash.state
val hash_dim : dim -> Base.Hash.hash_value
val compare_dim : dim -> dim -> Base.int
val sexp_of_dim : dim -> Sexplib0.Sexp.t
val dim_of_sexp : Sexplib0.Sexp.t -> dim
val var : dim_var -> dim
val dim : d:Base.int -> label:Base.string Base.option -> proj_id:Base.int Base.option -> dim
val is_var : dim -> Base.bool
val is_dim : dim -> Base.bool
val var_val : dim -> dim_var Base.option
val dim_val : dim -> ([ `d of Base.int ] * [ `label of Base.string Base.option ] * [ `proj_id of Base.int Base.option ]) Base.option
module Variants_of_dim : sig ... end
val get_dim : d:Base.int -> ?label:Base.string -> Base.unit -> dim
val dim_to_int_exn : dim -> Base.int
val dim_to_string : [> `Only_labels ] -> dim -> Base.string
type row_id
val sexp_of_row_id : row_id -> Sexplib0.Sexp.t
val row_id_of_sexp : Sexplib0.Sexp.t -> row_id
val compare_row_id : row_id -> row_id -> Base.int
val equal_row_id : row_id -> row_id -> Base.bool
val hash_fold_row_id : Base.Hash.state -> row_id -> Base.Hash.state
val hash_row_id : row_id -> Base.Hash.hash_value
type row_cmp
val row_id : sh_id:Base.int -> kind:kind -> row_id
type row_var
val sexp_of_row_var : row_var -> Sexplib0.Sexp.t
val row_var_of_sexp : Sexplib0.Sexp.t -> row_var
val compare_row_var : row_var -> row_var -> Base.int
val equal_row_var : row_var -> row_var -> Base.bool
val hash_fold_row_var : Base.Hash.state -> row_var -> Base.Hash.state
val hash_row_var : row_var -> Base.Hash.hash_value
val get_row_var : Base.unit -> row_var
type bcast =
  1. | Row_var of {
    1. v : row_var;
    2. beg_dims : dim Base.list;
    }
    (*

    The row can be inferred to have more axes.

    *)
  2. | Broadcastable
    (*

    The shape does not have more axes of this kind, but is "polymorphic".

    *)

A bcast specifies how axes of a single kind in a shape (i.e. the row) can adapt to other shapes.

val equal_bcast : bcast -> bcast -> Base.bool
val hash_fold_bcast : Base.Hash.state -> bcast -> Base.Hash.state
val hash_bcast : bcast -> Base.Hash.hash_value
val compare_bcast : bcast -> bcast -> Base.int
val sexp_of_bcast : bcast -> Sexplib0.Sexp.t
val bcast_of_sexp : Sexplib0.Sexp.t -> bcast
val row_var : v:row_var -> beg_dims:dim Base.list -> bcast
val broadcastable : bcast
val is_row_var : bcast -> Base.bool
val is_broadcastable : bcast -> Base.bool
val row_var_val : bcast -> ([ `v of row_var ] * [ `beg_dims of dim Base.list ]) Base.option
val broadcastable_val : bcast -> Base.unit Base.option
module Variants_of_bcast : sig ... end
type t = {
  1. dims : dim Base.list;
  2. bcast : bcast;
  3. id : row_id;
}
include Ppx_compare_lib.Equal.S with type t := t
val equal : t -> t -> bool
include Ppx_hash_lib.Hashable.S with type t := t
val hash_fold_t : Base.Hash.state -> t -> Base.Hash.state
val hash : t -> Base.Hash.hash_value
include Ppx_compare_lib.Comparable.S with type t := t
val compare : t -> t -> int
include Sexplib0.Sexpable.S with type t := t
val t_of_sexp : Sexplib0.Sexp.t -> t
val sexp_of_t : t -> Sexplib0.Sexp.t
val dims_label_assoc : t -> (Base.string * dim) Base.list
type environment
val sexp_of_environment : environment -> Sexplib0.Sexp.t
val environment_of_sexp : Sexplib0.Sexp.t -> environment
type error_trace = ..
type error_trace +=
  1. | Row_mismatch of t Base.list
  2. | Dim_mismatch of dim Base.list
  3. | Index_mismatch of Arrayjit.Indexing.axis_index Base.list
val sexp_of_error_trace : error_trace -> Base.Sexp.t
exception Shape_error of Base.string * error_trace Base.list
type dim_constraint =
  1. | Unconstrained_dim
  2. | At_least_dim of Base.int
val equal_dim_constraint : dim_constraint -> dim_constraint -> Base.bool
val hash_fold_dim_constraint : Base.Hash.state -> dim_constraint -> Base.Hash.state
val hash_dim_constraint : dim_constraint -> Base.Hash.hash_value
val compare_dim_constraint : dim_constraint -> dim_constraint -> Base.int
val sexp_of_dim_constraint : dim_constraint -> Sexplib0.Sexp.t
val dim_constraint_of_sexp : Sexplib0.Sexp.t -> dim_constraint
val unconstrained_dim : dim_constraint
val at_least_dim : Base.int -> dim_constraint
val is_unconstrained_dim : dim_constraint -> Base.bool
val is_at_least_dim : dim_constraint -> Base.bool
val unconstrained_dim_val : dim_constraint -> Base.unit Base.option
val at_least_dim_val : dim_constraint -> Base.int Base.option
module Variants_of_dim_constraint : sig ... end
type row_constraint =
  1. | Unconstrained
  2. | Total_elems of {
    1. nominator : Base.int;
    2. divided_by : dim_var_set;
    }
    (*

    The row or remainder of a row, inclusive of the further row spec, has this many elements.

    *)
val equal_row_constraint : row_constraint -> row_constraint -> Base.bool
val hash_fold_row_constraint : Base.Hash.state -> row_constraint -> Base.Hash.state
val hash_row_constraint : row_constraint -> Base.Hash.hash_value
val compare_row_constraint : row_constraint -> row_constraint -> Base.int
val sexp_of_row_constraint : row_constraint -> Sexplib0.Sexp.t
val row_constraint_of_sexp : Sexplib0.Sexp.t -> row_constraint
val unconstrained : row_constraint
val total_elems : nominator:Base.int -> divided_by:dim_var_set -> row_constraint
val is_unconstrained : row_constraint -> Base.bool
val is_total_elems : row_constraint -> Base.bool
val unconstrained_val : row_constraint -> Base.unit Base.option
val total_elems_val : row_constraint -> ([ `nominator of Base.int ] * [ `divided_by of dim_var_set ]) Base.option
module Variants_of_row_constraint : sig ... end
type dim_entry =
  1. | Solved_dim of dim
  2. | Bounds_dim of {
    1. cur : dim_var Base.list;
    2. subr : dim_var Base.list;
    3. lub : dim Base.option;
    4. constr : dim_constraint;
    }

An entry implements the inequalities cur >= v >= subr and/or the equality v = solved. The lists cur and subr must be sorted using the [@@deriving compare] comparison.

val sexp_of_dim_entry : dim_entry -> Sexplib0.Sexp.t
val dim_entry_of_sexp : Sexplib0.Sexp.t -> dim_entry
type row_entry =
  1. | Solved_row of t
  2. | Bounds_row of {
    1. cur : row_var Base.list;
    2. subr : row_var Base.list;
    3. lub : t Base.option;
    4. constr : row_constraint;
    }
val sexp_of_row_entry : row_entry -> Sexplib0.Sexp.t
val row_entry_of_sexp : Sexplib0.Sexp.t -> row_entry
type constraint_ =
  1. | Dim_eq of {
    1. d1 : dim;
    2. d2 : dim;
    }
  2. | Row_eq of {
    1. r1 : t;
    2. r2 : t;
    }
  3. | Dim_ineq of {
    1. cur : dim;
    2. subr : dim;
    }
  4. | Row_ineq of {
    1. cur : t;
    2. subr : t;
    }
  5. | Dim_constr of {
    1. d : dim;
    2. constr : dim_constraint;
    }
  6. | Row_constr of {
    1. r : t;
    2. constr : row_constraint;
    }
  7. | Terminal_dim of dim
  8. | Terminal_row of t
val compare_constraint_ : constraint_ -> constraint_ -> Base.int
val equal_constraint_ : constraint_ -> constraint_ -> Base.bool
val sexp_of_constraint_ : constraint_ -> Sexplib0.Sexp.t
val constraint__of_sexp : Sexplib0.Sexp.t -> constraint_
val dim_eq : d1:dim -> d2:dim -> constraint_
val row_eq : r1:t -> r2:t -> constraint_
val dim_ineq : cur:dim -> subr:dim -> constraint_
val row_ineq : cur:t -> subr:t -> constraint_
val dim_constr : d:dim -> constr:dim_constraint -> constraint_
val row_constr : r:t -> constr:row_constraint -> constraint_
val terminal_dim : dim -> constraint_
val terminal_row : t -> constraint_
val is_dim_eq : constraint_ -> Base.bool
val is_row_eq : constraint_ -> Base.bool
val is_dim_ineq : constraint_ -> Base.bool
val is_row_ineq : constraint_ -> Base.bool
val is_dim_constr : constraint_ -> Base.bool
val is_row_constr : constraint_ -> Base.bool
val is_terminal_dim : constraint_ -> Base.bool
val is_terminal_row : constraint_ -> Base.bool
val dim_eq_val : constraint_ -> ([ `d1 of dim ] * [ `d2 of dim ]) Base.option
val row_eq_val : constraint_ -> ([ `r1 of t ] * [ `r2 of t ]) Base.option
val dim_ineq_val : constraint_ -> ([ `cur of dim ] * [ `subr of dim ]) Base.option
val row_ineq_val : constraint_ -> ([ `cur of t ] * [ `subr of t ]) Base.option
val dim_constr_val : constraint_ -> ([ `d of dim ] * [ `constr of dim_constraint ]) Base.option
val row_constr_val : constraint_ -> ([ `r of t ] * [ `constr of row_constraint ]) Base.option
val terminal_dim_val : constraint_ -> dim Base.option
val terminal_row_val : constraint_ -> t Base.option
module Variants_of_constraint_ : sig ... end
type stage =
  1. | Stage1
  2. | Stage2
  3. | Stage3
  4. | Stage4
  5. | Stage5
  6. | Stage6
  7. | Stage7
val sexp_of_stage : stage -> Sexplib0.Sexp.t
val stage_of_sexp : Sexplib0.Sexp.t -> stage
val equal_stage : stage -> stage -> Base.bool
val compare_stage : stage -> stage -> Base.int
val subst_row : environment -> t -> t
val unify_row : stage:stage -> (t * t) -> environment -> constraint_ Base.list * environment
val empty_env : environment
val eliminate_variables : environment -> t -> constraint_ Base.list
val solve_inequalities : stage:stage -> constraint_ Base.list -> environment -> constraint_ Base.list * environment
val row_to_labels : environment -> t -> Base.string Base.array
type proj
val compare_proj : proj -> proj -> Base.int
val equal_proj : proj -> proj -> Base.bool
val sexp_of_proj : proj -> Sexplib0.Sexp.t
val proj_of_sexp : Sexplib0.Sexp.t -> proj
type proj_env
val sexp_of_proj_env : proj_env -> Sexplib0.Sexp.t
val proj_env_of_sexp : Sexplib0.Sexp.t -> proj_env
val fresh_row_proj : t -> t
type proj_equation =
  1. | Proj_eq of proj * proj
    (*

    Two projections are the same, e.g. two axes share the same iterator.

    *)
  2. | Iterated of proj
    (*

    The projection needs to be an iterator even if an axis is not matched with another axis, e.g. for broadcasted-to axes of a tensor assigned a constant.

    *)
val compare_proj_equation : proj_equation -> proj_equation -> Base.int
val equal_proj_equation : proj_equation -> proj_equation -> Base.bool
val sexp_of_proj_equation : proj_equation -> Sexplib0.Sexp.t
val proj_equation_of_sexp : Sexplib0.Sexp.t -> proj_equation
val solve_proj_equations : proj_equation Base.list -> proj_env
val get_proj_index : proj_env -> dim -> Arrayjit.Indexing.axis_index
val get_product_proj : proj_env -> dim -> (Base.int * Base.int) Base.option
val proj_to_iterator : proj_env -> Base.int -> Arrayjit.Indexing.symbol
OCaml

Innovation. Community. Security.