Class: ReplicateClient::Training

Inherits:
Object
  • Object
show all
Defined in:
lib/replicate-client/training.rb

Defined Under Namespace

Modules: Status

Constant Summary collapse

INDEX_PATH =
"/trainings"

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(attributes) ⇒ ReplicateClient::Training

Initialize a new training instance.

Parameters:

  • attributes (Hash)

    The attributes of the training.



191
192
193
# File 'lib/replicate-client/training.rb', line 191

# Initialize a new training instance.
#
# @param attributes [Hash] the attributes of the training; passed unchanged to
#   reset_attributes (defined outside this view)
def initialize(attributes)
  reset_attributes(attributes)
end

Instance Attribute Details

#completed_at ⇒ Time?

The timestamp when the training was completed.

Returns:

  • (Time, nil)


154
155
156
# File 'lib/replicate-client/training.rb', line 154

# The timestamp when the training was completed.
#
# @return [Time, nil]
def completed_at
  @completed_at
end

#created_at ⇒ String

The timestamp when the training was created.

Returns:

  • (String)


149
150
151
# File 'lib/replicate-client/training.rb', line 149

# The timestamp when the training was created.
#
# NOTE(review): documented as String, while started_at and completed_at are
# documented as Time — confirm which type is actually stored.
#
# @return [String]
def created_at
  @created_at
end

#error ⇒ String?

The error message, if any, encountered during the training process.

Returns:

  • (String, nil)


169
170
171
# File 'lib/replicate-client/training.rb', line 169

# The error message, if any, encountered during the training process.
#
# @return [String, nil]
def error
  @error
end

#id ⇒ String

The unique identifier of the training.

Returns:

  • (String)


123
124
125
# File 'lib/replicate-client/training.rb', line 123

# The unique identifier of the training.
#
# @return [String]
def id
  @id
end

#input ⇒ Hash

The input data provided for the training.

Returns:

  • (Hash)


138
139
140
# File 'lib/replicate-client/training.rb', line 138

# The input data provided for the training.
#
# @return [Hash]
def input
  @input
end

#logs ⇒ String

The logs generated during the training process.

Returns:

  • (String)


164
165
166
# File 'lib/replicate-client/training.rb', line 164

# The logs generated during the training process.
#
# @return [String]
def logs
  @logs
end

#metrics ⇒ Hash?

The metrics generated during the training process.

Returns:

  • (Hash, nil)


184
185
186
# File 'lib/replicate-client/training.rb', line 184

# The metrics generated during the training process.
#
# @return [Hash, nil]
def metrics
  @metrics
end

#model_full_name ⇒ String

The full model name in the format “owner/name”.

Returns:

  • (String)


128
129
130
# File 'lib/replicate-client/training.rb', line 128

# The full model name in the format "owner/name".
#
# @return [String]
def model_full_name
  @model_full_name
end

#output ⇒ Hash?

The output data generated during the training process.

Returns:

  • (Hash, nil)


179
180
181
# File 'lib/replicate-client/training.rb', line 179

# The output data generated during the training process.
#
# @return [Hash, nil]
def output
  @output
end

#started_at ⇒ Time?

The timestamp when the training was started.

Returns:

  • (Time, nil)


159
160
161
# File 'lib/replicate-client/training.rb', line 159

# The timestamp when the training was started.
#
# @return [Time, nil]
def started_at
  @started_at
end

#status ⇒ String

The current status of the training. Possible values: “starting”, “processing”, “succeeded”, “failed”, “canceled”.

Returns:

  • (String)


144
145
146
# File 'lib/replicate-client/training.rb', line 144

# The current status of the training.
# Possible values: "starting", "processing", "succeeded", "failed", "canceled".
#
# @return [String]
def status
  @status
end

#urls ⇒ Hash

URLs related to the training, such as those for retrieving or canceling it.

Returns:

  • (Hash)


174
175
176
# File 'lib/replicate-client/training.rb', line 174

# URLs related to the training, such as those for retrieving or canceling it.
#
# @return [Hash]
def urls
  @urls
end

#version_id ⇒ String

The version ID of the model being trained.

Returns:

  • (String)


133
134
135
# File 'lib/replicate-client/training.rb', line 133

# The version ID of the model being trained.
#
# @return [String]
def version_id
  @version_id
end

Class Method Details

.auto_paging_each {|ReplicateClient::Training| ... } ⇒ void

This method returns an undefined value.

List all trainings.

Yields:

  • (ReplicateClient::Training)



21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# File 'lib/replicate-client/training.rb', line 21

# Iterate over every training, following pagination cursors page by page.
#
# Each page is fetched from INDEX_PATH through ReplicateClient.client. Every entry
# of the page's "results" array is wrapped with `new` and yielded to the block.
# Iteration stops when the response has no "next" URL or that URL carries no
# "cursor" query parameter.
#
# @yield [ReplicateClient::Training] each training, in the order pages return them
# @return [void]
def auto_paging_each(&block)
  cursor = nil

  loop do
    # The cursor was percent-decoded when it was extracted from the "next" URL, so it
    # must be form-encoded again; otherwise "+", "=", "/" or "&" corrupt the query.
    url_params = cursor ? "?#{URI.encode_www_form(cursor: cursor)}" : ""
    attributes = ReplicateClient.client.get("#{INDEX_PATH}#{url_params}")

    attributes["results"].map { |training| new(training) }.each(&block)

    next_url = attributes["next"]
    break unless next_url

    # `query` is nil for a URL without a query string; `to_s` keeps
    # decode_www_form from raising and yields no cursor instead.
    cursor = URI.decode_www_form(URI.parse(next_url).query.to_s).to_h["cursor"]
    break if cursor.nil?
  end
end

.build_path(id:) ⇒ String

Build the path for a specific training.

Parameters:

  • id (String)

    The id of the training.

Returns:

  • (String)


115
116
117
# File 'lib/replicate-client/training.rb', line 115

# Compose the API path that addresses a single training.
#
# @param id [String] the id of the training
# @return [String] INDEX_PATH followed by "/" and the id
def build_path(id:)
  format("%s/%s", INDEX_PATH, id)
end

.cancel!(id) ⇒ void

This method returns an undefined value.

Cancel a training.

Parameters:

  • id (String)

    The id of the training.



105
106
107
108
# File 'lib/replicate-client/training.rb', line 105

# Cancel the training identified by the given id.
#
# Sends a POST to the training's path (see build_path) with "/cancel" appended,
# using ReplicateClient.client.
#
# @param id [String] the id of the training
# @return [void]
def cancel!(id)
  ReplicateClient.client.post(build_path(id: id) + "/cancel")
end

.create!(owner:, name:, version:, destination:, input:, webhook_url: nil, webhook_events_filter: nil) ⇒ ReplicateClient::Training

Create a new training.

The destination may be given as a model instance or as a string in “owner/name” format.

Parameters:

  • owner (String)

    The owner of the model.

  • name (String)

    The name of the model.

  • version (ReplicateClient::Model::Version, String)

    The version of the model to train.

  • destination (ReplicateClient::Model, String)

    The destination model instance or string in “owner/name” format.

  • input (Hash)

    The input data for the training.

  • webhook_url (String, nil) (defaults to: nil)

    A URL to receive webhook notifications.

  • webhook_events_filter (Array, nil) (defaults to: nil)

    The events to trigger webhook requests.

Returns:

  • (ReplicateClient::Training)



49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# File 'lib/replicate-client/training.rb', line 49

# Create a new training.
#
# Posts to "/models/{owner}/{name}/versions/{version}/trainings" through
# ReplicateClient.client and wraps the response with `new`.
#
# @param owner [String] the owner of the model
# @param name [String] the name of the model
# @param version [ReplicateClient::Model::Version, String] the version to train;
#   a Version instance is reduced to its id, anything else is used as given
# @param destination [ReplicateClient::Model, String] the destination model; a
#   Model instance is reduced to its full_name, anything else is used as given
# @param input [Hash] the input data for the training
# @param webhook_url [String, nil] a URL to receive webhook notifications; falls
#   back to ReplicateClient.configuration.webhook_url when nil
# @param webhook_events_filter [Array, nil] the events to trigger webhook requests
# @return [ReplicateClient::Training]
def create!(owner:, name:, version:, destination:, input:, webhook_url: nil, webhook_events_filter: nil)
  destination_name =
    case destination
    when ReplicateClient::Model then destination.full_name
    else destination
    end

  version_ref =
    case version
    when ReplicateClient::Model::Version then version.id
    else version
    end

  endpoint = "/models/#{owner}/#{name}/versions/#{version_ref}/trainings"
  payload = {
    destination: destination_name,
    input: input,
    webhook: webhook_url || ReplicateClient.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter
  }

  new(ReplicateClient.client.post(endpoint, payload))
end

.create_for_model!(model:, destination:, input:, webhook_url: nil, webhook_events_filter: nil) ⇒ ReplicateClient::Training

Create a new training for a specific model.

Parameters:

  • model (ReplicateClient::Model, String)

    The model instance or a string representing the model ID.

  • destination (ReplicateClient::Model, String)

    The destination model or full name in “owner/name” format.

  • input (Hash)

    The input data for the training.

  • webhook_url (String, nil) (defaults to: nil)

    A URL to receive webhook notifications.

  • webhook_events_filter (Array, nil) (defaults to: nil)

    The events to trigger webhook requests.

Returns:

  • (ReplicateClient::Training)

Raises:

  • (ArgumentError)


74
75
76
77
78
79
80
81
82
83
84
85
86
87
# File 'lib/replicate-client/training.rb', line 74

# Create a new training for a specific model.
#
# A non-Model argument is looked up with ReplicateClient::Model.find. The resolved
# model's owner, name and version_id are then forwarded to create!.
#
# @param model [ReplicateClient::Model, String] the model instance or a string
#   representing the model ID
# @param destination [ReplicateClient::Model, String] the destination model or full
#   name in "owner/name" format
# @param input [Hash] the input data for the training
# @param webhook_url [String, nil] a URL to receive webhook notifications; falls
#   back to ReplicateClient.configuration.webhook_url when nil
# @param webhook_events_filter [Array, nil] the events to trigger webhook requests
# @return [ReplicateClient::Training]
# @raise [ArgumentError] when no model could be resolved
def create_for_model!(model:, destination:, input:, webhook_url: nil, webhook_events_filter: nil)
  resolved =
    case model
    when ReplicateClient::Model then model
    else ReplicateClient::Model.find(model)
    end
  raise ArgumentError, "Invalid model" unless resolved

  training_options = {
    owner: resolved.owner,
    name: resolved.name,
    version: resolved.version_id,
    destination: destination,
    input: input,
    webhook_url: webhook_url || ReplicateClient.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter
  }

  create!(**training_options)
end

.find(id) ⇒ ReplicateClient::Training

Find a training by id.

Parameters:

  • id (String)

    The id of the training.

Returns:

  • (ReplicateClient::Training)



94
95
96
97
98
# File 'lib/replicate-client/training.rb', line 94

# Find a training by id.
#
# Fetches the training's path (see build_path) through ReplicateClient.client and
# wraps the returned attributes with `new`.
#
# @param id [String] the id of the training
# @return [ReplicateClient::Training]
def find(id)
  new(ReplicateClient.client.get(build_path(id: id)))
end

Instance Method Details

#cancel! ⇒ void

This method returns an undefined value.

Cancel the training.



233
234
235
# File 'lib/replicate-client/training.rb', line 233

# Cancel the training.
#
# Delegates to the class-level ReplicateClient::Training.cancel! with this
# instance's id. This method does not itself refresh the instance's attributes.
#
# @return [void]
def cancel!
  ReplicateClient::Training.cancel!(id)
end

#canceled? ⇒ Boolean

Check if the training was canceled.

Returns:

  • (Boolean)


226
227
228
# File 'lib/replicate-client/training.rb', line 226

# Check if the training was canceled.
#
# @return [Boolean] true when status equals Status::CANCELED
def canceled?
  status == Status::CANCELED
end

#failed? ⇒ Boolean

Check if the training has failed.

Returns:

  • (Boolean)


219
220
221
# File 'lib/replicate-client/training.rb', line 219

# Check if the training has failed.
#
# @return [Boolean] true when status equals Status::FAILED
def failed?
  status == Status::FAILED
end

#model ⇒ ReplicateClient::Model

The model instance of the training.



248
249
250
# File 'lib/replicate-client/training.rb', line 248

# The model instance of the training.
#
# Looked up with ReplicateClient::Model.find using model_full_name and version_id,
# then cached in @model so later calls reuse the same object.
#
# @return [ReplicateClient::Model]
def model
  return @model if @model

  @model = ReplicateClient::Model.find(model_full_name, version_id: version_id)
end

#processing? ⇒ Boolean

Check if the training is processing.

Returns:

  • (Boolean)


205
206
207
# File 'lib/replicate-client/training.rb', line 205

# Check if the training is processing.
#
# @return [Boolean] true when status equals Status::PROCESSING
def processing?
  status == Status::PROCESSING
end

#reload! ⇒ void

This method returns an undefined value.

Reload the training.



240
241
242
243
# File 'lib/replicate-client/training.rb', line 240

# Reload the training.
#
# Fetches this training's path (Training.build_path with the current id) through
# ReplicateClient.client and hands the response to reset_attributes.
#
# @return [void]
def reload!
  reset_attributes(ReplicateClient.client.get(Training.build_path(id: id)))
end

#starting? ⇒ Boolean

Check if the training is starting.

Returns:

  • (Boolean)


198
199
200
# File 'lib/replicate-client/training.rb', line 198

# Check if the training is starting.
#
# @return [Boolean] true when status equals Status::STARTING
def starting?
  status == Status::STARTING
end

#succeeded? ⇒ Boolean

Check if the training has succeeded.

Returns:

  • (Boolean)


212
213
214
# File 'lib/replicate-client/training.rb', line 212

# Check if the training has succeeded.
#
# @return [Boolean] true when status equals Status::SUCCEEDED
def succeeded?
  status == Status::SUCCEEDED
end

#version ⇒ ReplicateClient::Model::Version

The version instance of the training.



255
256
257
# File 'lib/replicate-client/training.rb', line 255

# The version instance of the training.
#
# Taken from model.version and cached in @version so later calls reuse it.
#
# @return [ReplicateClient::Model::Version]
def version
  return @version if @version

  @version = model.version
end