Module: Safetensors::Torch
- Defined in:
- lib/safetensors/torch.rb
Constant Summary
- TYPES =
{ "F64" => :float64, "F32" => :float32, "F16" => :float16, "BF16" => :bfloat16, "I64" => :int64, "U64" => :uint64, "I32" => :int32, "U32" => :uint32, "I16" => :int16, "U16" => :uint16, "I8" => :int8, "U8" => :uint8, "BOOL" => :bool, "F8_E4M3" => :float8_e4m3fn, "F8_E5M2" => :float8_e5m2 }
Class Method Summary
- .load(data) ⇒ Object
- .load_file(filename, device: "cpu") ⇒ Object
- .save(tensors, metadata: nil) ⇒ Object
- .save_file(tensors, filename, metadata: nil) ⇒ Object
Class Method Details
.load(data) ⇒ Object
Source: lib/safetensors/torch.rb, lines 40-43
# File 'lib/safetensors/torch.rb', line 40
#
# Deserializes an in-memory safetensors payload and converts the resulting
# views with +_view2torch+.
#
# @param data [Object] serialized safetensors bytes handed to
#   Safetensors.deserialize (presumably a binary String, e.g. the output of
#   .save — TODO confirm)
# @return [Object] whatever +_view2torch+ builds from the deserialized views
#   (presumably a Hash of name => tensor — confirm against _view2torch)
def load(data)
  _view2torch(Safetensors.deserialize(data))
end
.load_file(filename, device: "cpu") ⇒ Object
Source: lib/safetensors/torch.rb, lines 30-38
# File 'lib/safetensors/torch.rb', line 30
#
# Reads every tensor stored in a safetensors file into a Hash.
#
# The file is opened via Safetensors.safe_open with framework "torch". For
# each key the handle reports, the value returned by +get_tensor+ is copied
# into the result.
#
# @param filename [String] path of the file passed to Safetensors.safe_open
# @param device [String] device forwarded to Safetensors.safe_open
#   (defaults to "cpu")
# @return [Hash] key => tensor as returned by the handle's +get_tensor+;
#   an empty Hash when the handle reports no keys
def load_file(filename, device: "cpu")
  tensors = {}
  Safetensors.safe_open(filename, framework: "torch", device: device) do |handle|
    # Fill the pre-seeded hash so the return value does not depend on what
    # safe_open itself returns.
    handle.keys.each_with_object(tensors) do |name, acc|
      acc[name] = handle.get_tensor(name)
    end
  end
  tensors
end
.save(tensors, metadata: nil) ⇒ Object
Source: lib/safetensors/torch.rb, lines 22-24
# File 'lib/safetensors/torch.rb', line 22
#
# Serializes tensors to an in-memory safetensors payload.
#
# @param tensors [Object] tensors to store; normalized with +_flatten+ before
#   serialization (presumably a Hash of name => tensor — confirm)
# @param metadata [Object, nil] optional metadata forwarded unchanged to
#   Safetensors.serialize
# @return [Object] whatever Safetensors.serialize returns (presumably the
#   serialized bytes — confirm)
def save(tensors, metadata: nil)
  flat = _flatten(tensors)
  Safetensors.serialize(flat, metadata: metadata)
end
.save_file(tensors, filename, metadata: nil) ⇒ Object
Source: lib/safetensors/torch.rb, lines 26-28
# File 'lib/safetensors/torch.rb', line 26
#
# Serializes tensors and hands them to Safetensors.serialize_file together
# with +filename+ (presumably writing the payload to that path — confirm).
#
# @param tensors [Object] tensors to store; normalized with +_flatten+ first
# @param filename [String] destination path forwarded to
#   Safetensors.serialize_file
# @param metadata [Object, nil] optional metadata forwarded unchanged
# @return [Object] whatever Safetensors.serialize_file returns
def save_file(tensors, filename, metadata: nil)
  flat = _flatten(tensors)
  Safetensors.serialize_file(flat, filename, metadata: metadata)
end