Class: ReplicateClient::Prediction

Inherits:
Object
  • Object
show all
Defined in:
lib/replicate-client/prediction.rb

Defined Under Namespace

Modules: Status

Constant Summary collapse

INDEX_PATH =
"/predictions"

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(attributes) ⇒ Prediction

Returns a new instance of Prediction.



198
199
200
# File 'lib/replicate-client/prediction.rb', line 198

# Wraps a raw prediction payload returned by the API.
#
# @param attrs [Hash] attribute hash forwarded to #reset_attributes
def initialize(attrs)
  reset_attributes(attrs)
end

Instance Attribute Details

#completed_at ⇒ Time

The date the prediction was completed.

Returns:

  • (Time)


181
182
183
# File 'lib/replicate-client/prediction.rb', line 181

# @return [Time] the date the prediction was completed.
def completed_at; @completed_at; end

#created_at ⇒ Time

The date the prediction was created.

Returns:

  • (Time)


166
167
168
# File 'lib/replicate-client/prediction.rb', line 166

# @return [Time] the date the prediction was created.
def created_at; @created_at; end

#data_removed ⇒ Time

The date the prediction was removed.

Returns:

  • (Time)


171
172
173
# File 'lib/replicate-client/prediction.rb', line 171

# @return [Time] the value stored for data_removed (documented as the removal date).
# NOTE(review): confirm the Time type against the API payload, which may report
# this field as a boolean flag rather than a timestamp.
def data_removed; @data_removed; end

#error ⇒ String

The error message for the prediction.

Returns:

  • (String)


156
157
158
# File 'lib/replicate-client/prediction.rb', line 156

# @return [String] the error message for the prediction.
def error; @error; end

#id ⇒ String

The ID of the prediction.

Returns:

  • (String)


131
132
133
# File 'lib/replicate-client/prediction.rb', line 131

# @return [String] the ID of the prediction.
def id; @id; end

#input ⇒ Hash

The input data for the prediction.

Returns:

  • (Hash)


146
147
148
# File 'lib/replicate-client/prediction.rb', line 146

# @return [Hash] the input data for the prediction.
def input; @input; end

#logs ⇒ String

The logs for the prediction.

Returns:

  • (String)


196
197
198
# File 'lib/replicate-client/prediction.rb', line 196

# @return [String] the logs for the prediction.
def logs; @logs; end

#metrics ⇒ Hash

The metrics for the prediction.

Returns:

  • (Hash)


186
187
188
# File 'lib/replicate-client/prediction.rb', line 186

# @return [Hash] the metrics for the prediction.
def metrics; @metrics; end

#model_name ⇒ String

The model used for the prediction.

Returns:

  • (String)


141
142
143
# File 'lib/replicate-client/prediction.rb', line 141

# @return [String] the name of the model used for the prediction.
def model_name; @model_name; end

#output ⇒ Hash

The output data for the prediction.

Returns:

  • (Hash)


151
152
153
# File 'lib/replicate-client/prediction.rb', line 151

# @return [Hash] the output data for the prediction.
def output; @output; end

#started_at ⇒ Time

The date the prediction was started.

Returns:

  • (Time)


176
177
178
# File 'lib/replicate-client/prediction.rb', line 176

# @return [Time] the date the prediction was started.
def started_at; @started_at; end

#status ⇒ String

The status of the prediction.

Returns:

  • (String)


161
162
163
# File 'lib/replicate-client/prediction.rb', line 161

# @return [String] the status of the prediction.
def status; @status; end

#urls ⇒ Hash

The URLs for the prediction.

Returns:

  • (Hash)


191
192
193
# File 'lib/replicate-client/prediction.rb', line 191

# @return [Hash] the URLs for the prediction.
def urls; @urls; end

#version_id ⇒ String

The version of the model used for the prediction.

Returns:

  • (String)


136
137
138
# File 'lib/replicate-client/prediction.rb', line 136

# @return [String] the ID of the model version used for the prediction.
def version_id; @version_id; end

Class Method Details

.build_path(id) ⇒ String

Build the path for the prediction.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (String)


114
115
116
# File 'lib/replicate-client/prediction.rb', line 114

# Path of a single prediction resource: INDEX_PATH followed by "/" and the ID.
#
# @param id [String] the ID of the prediction
# @return [String]
def build_path(id)
  format("%s/%s", INDEX_PATH, id)
end

.cancel!(id) ⇒ void

This method returns an undefined value.

Cancel a prediction.

Parameters:

  • id (String)

    The ID of the prediction.



123
124
125
# File 'lib/replicate-client/prediction.rb', line 123

# Ask the API to cancel the prediction with the given ID.
#
# @param id [String] the ID of the prediction
# @return [void]
def cancel!(id)
  cancel_path = "#{build_path(id)}/cancel"
  ReplicateClient.client.post(cancel_path)
end

.create!(version:, input:, webhook_url: nil, webhook_events_filter: nil) ⇒ ReplicateClient::Prediction

Create a new prediction for a version.

Parameters:

  • version (String, ReplicateClient::Model::Version)

    The version of the model to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

Returns:

  • (ReplicateClient::Prediction)

24
25
26
27
28
29
30
31
32
33
34
35
# File 'lib/replicate-client/prediction.rb', line 24

# Create a new prediction for a specific model version.
#
# @param version [String, ReplicateClient::Model::Version] a version object or a raw version ID
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] webhook target; falls back to
#   ReplicateClient.configuration.webhook_url when nil
# @param webhook_events_filter [Array<Symbol>, nil] events to send to the webhook
#   (sent as strings)
# @return [ReplicateClient::Prediction]
def create!(version:, input:, webhook_url: nil, webhook_events_filter: nil)
  version_id =
    if version.is_a?(Model::Version)
      version.id
    else
      version
    end

  payload = {
    version: version_id,
    input: input,
    webhook: webhook_url || ReplicateClient.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter&.map(&:to_s)
  }

  new(ReplicateClient.client.post(INDEX_PATH, payload))
end

.create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil) ⇒ ReplicateClient::Prediction

Create a new prediction for a deployment.

Parameters:

  • deployment (String, ReplicateClient::Deployment)

    The deployment to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

Returns:

  • (ReplicateClient::Prediction)

45
46
47
48
49
50
51
52
53
54
55
# File 'lib/replicate-client/prediction.rb', line 45

# Create a new prediction against a deployment.
#
# @param deployment [String, ReplicateClient::Deployment] a deployment object
#   (anything responding to #path) or an "owner/name" string
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] webhook target; falls back to
#   ReplicateClient.configuration.webhook_url when nil
# @param webhook_events_filter [Array<Symbol>, nil] events to send to the webhook
#   (sent as strings)
# @return [ReplicateClient::Prediction]
# @raise [ArgumentError] if a String deployment is not of the form "owner/name"
def create_for_deployment!(deployment:, input:, webhook_url: nil, webhook_events_filter: nil)
  # The documented String form used to crash with NoMethodError because String has
  # no #path. Build the path Replicate's deployment-prediction endpoint expects
  # (/deployments/{owner}/{name}) directly.
  # NOTE(review): assumes Deployment#path uses the same "/deployments/owner/name"
  # shape — confirm against lib/replicate-client/deployment.rb.
  deployment_path =
    case deployment
    when String
      unless deployment.match?(%r{\A[^/\s]+/[^/\s]+\z})
        raise ArgumentError, "deployment must be \"owner/name\", got #{deployment.inspect}"
      end

      "/deployments/#{deployment}"
    else
      deployment.path
    end

  args = {
    input: input,
    webhook: webhook_url || ReplicateClient.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter&.map(&:to_s)
  }

  prediction = ReplicateClient.client.post("#{deployment_path}#{INDEX_PATH}", args)

  new(prediction)
end

.create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil) ⇒ ReplicateClient::Prediction

Create a new prediction for a model.

Parameters:

  • model (String, ReplicateClient::Model)

    The model to use for the prediction.

  • input (Hash)

    The input data for the prediction.

  • webhook_url (String) (defaults to: nil)

    The URL to send webhook events to.

  • webhook_events_filter (Array<Symbol>) (defaults to: nil)

    The events to send to the webhook.

Returns:

  • (ReplicateClient::Prediction)

65
66
67
68
69
70
71
72
73
74
75
76
77
# File 'lib/replicate-client/prediction.rb', line 65

# Create a new prediction against an official model (no explicit version).
#
# @param model [String, ReplicateClient::Model] a model object or a model name
#   parsed by Model.parse_model_name
# @param input [Hash] the input data for the prediction
# @param webhook_url [String, nil] webhook target; falls back to
#   ReplicateClient.configuration.webhook_url when nil
# @param webhook_events_filter [Array<Symbol>, nil] events to send to the webhook
#   (sent as strings)
# @return [ReplicateClient::Prediction]
def create_for_official_model!(model:, input:, webhook_url: nil, webhook_events_filter: nil)
  model_path =
    if model.is_a?(Model)
      model.path
    else
      Model.build_path(**Model.parse_model_name(model))
    end

  body = {
    input: input,
    webhook: webhook_url || ReplicateClient.configuration.webhook_url,
    webhook_events_filter: webhook_events_filter&.map(&:to_s)
  }

  new(ReplicateClient.client.post("#{model_path}#{INDEX_PATH}", body))
end

.find(id) ⇒ ReplicateClient::Prediction

Find a prediction.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (ReplicateClient::Prediction)

84
85
86
87
# File 'lib/replicate-client/prediction.rb', line 84

# Fetch a prediction by ID from the API.
#
# @param id [String] the ID of the prediction
# @return [ReplicateClient::Prediction]
def find(id)
  new(ReplicateClient.client.get(build_path(id)))
end

.find_by(id:) ⇒ ReplicateClient::Prediction

Find a prediction, returning nil instead of raising ReplicateClient::NotFoundError when it does not exist.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (ReplicateClient::Prediction, nil)

    nil if no prediction with that ID exists.

103
104
105
106
107
# File 'lib/replicate-client/prediction.rb', line 103

# Look up a prediction, returning nil instead of raising when it is not found.
#
# @param id [String] the ID of the prediction
# @return [ReplicateClient::Prediction, nil]
def find_by(id:)
  begin
    find_by!(id: id)
  rescue ReplicateClient::NotFoundError
    nil
  end
end

.find_by!(id:) ⇒ ReplicateClient::Prediction

Find a prediction; errors from the lookup (such as ReplicateClient::NotFoundError) are raised to the caller.

Parameters:

  • id (String)

    The ID of the prediction.

Returns:

  • (ReplicateClient::Prediction)

94
95
96
# File 'lib/replicate-client/prediction.rb', line 94

# Keyword-argument alias for .find; lookup errors propagate unchanged.
#
# @param id [String] the ID of the prediction
# @return [ReplicateClient::Prediction]
def find_by!(id:); find(id); end

Instance Method Details

#cancel! ⇒ void

This method returns an undefined value.

Cancel the prediction.



227
228
229
# File 'lib/replicate-client/prediction.rb', line 227

# Cancel this prediction by delegating to the class-level Prediction.cancel!.
#
# @return [void]
def cancel!
  prediction_id = id
  Prediction.cancel!(prediction_id)
end

#canceled? ⇒ Boolean

Check if the prediction is canceled.

Returns:

  • (Boolean)


248
249
250
# File 'lib/replicate-client/prediction.rb', line 248

# @return [Boolean] true when the current status equals Status::CANCELED.
def canceled?; status == Status::CANCELED; end

#failed? ⇒ Boolean

Check if the prediction has failed.

Returns:

  • (Boolean)


241
242
243
# File 'lib/replicate-client/prediction.rb', line 241

# @return [Boolean] true when the current status equals Status::FAILED.
def failed?; status == Status::FAILED; end

#model ⇒ ReplicateClient::Model

The model used for the prediction.



213
214
215
# File 'lib/replicate-client/prediction.rb', line 213

# The model used for the prediction, fetched on first access and memoized.
#
# @return [ReplicateClient::Model]
def model
  return @model if @model

  @model = Model.find(@model_name, version_id: @version_id)
end

#processing? ⇒ Boolean

Check if the prediction is processing.

Returns:

  • (Boolean)


262
263
264
# File 'lib/replicate-client/prediction.rb', line 262

# @return [Boolean] true when the current status equals Status::PROCESSING.
def processing?; status == Status::PROCESSING; end

#reload! ⇒ ReplicateClient::Prediction

Reload the prediction.



205
206
207
208
# File 'lib/replicate-client/prediction.rb', line 205

# Re-fetch this prediction from the API and refresh its attributes in place.
#
# @return [ReplicateClient::Prediction] self, so calls can be chained
#   (e.g. prediction.reload!.succeeded?)
def reload!
  attributes = ReplicateClient.client.get(Prediction.build_path(@id))
  reset_attributes(attributes)
  # Return self explicitly: the documented return type is the prediction, and
  # relying on reset_attributes' return value would hand back whatever it returns.
  self
end

#starting? ⇒ Boolean

Check if the prediction is starting.

Returns:

  • (Boolean)


255
256
257
# File 'lib/replicate-client/prediction.rb', line 255

# @return [Boolean] true when the current status equals Status::STARTING.
def starting?; status == Status::STARTING; end

#succeeded? ⇒ Boolean

Check if the prediction has succeeded.

Returns:

  • (Boolean)


234
235
236
# File 'lib/replicate-client/prediction.rb', line 234

# @return [Boolean] true when the current status equals Status::SUCCEEDED.
def succeeded?; status == Status::SUCCEEDED; end

#version ⇒ ReplicateClient::Model::Version

The version of the model used for the prediction.



220
221
222
# File 'lib/replicate-client/prediction.rb', line 220

# The model version used for the prediction, taken from #model and memoized.
#
# @return [ReplicateClient::Model::Version]
def version
  return @version if @version

  @version = model.version
end