Class: Replicate::Model

Inherits: Object
Defined in:
lib/replicate/model.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(client:, path:) ⇒ Model

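Returns a new instance of Model. The constructor requests the model from the API ("models/<path>") to resolve the ID of its latest version.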



# File 'lib/replicate/model.rb', lines 5-10

# Creates a handle for a hosted model and resolves its latest version.
#
# Issues a GET request to "models/<path>" during construction, so building a
# Model requires a working client.
#
# @param client [#requests] object used to issue API requests; it must respond
#   to requests(method:, path:, ...) and return an object exposing #body
# @param path [String] model path, interpolated into "models/<path>"
# @raise [RuntimeError] if the response body has no latest_version.id
def initialize(client:, path:)
  @client = client
  @path = path
  response = client.requests(method: "GET", path: "models/#{path}")
  # dig avoids an opaque NoMethodError on nil when the response lacks a
  # latest_version entry; a missing id is reported the same way instead of
  # leaving version_id nil for later requests.
  @version_id = response.body.dig("latest_version", "id")
  raise "Model #{path} has no latest version" if @version_id.nil?
end

Instance Attribute Details

#path ⇒ Object (readonly)

Returns the value of attribute path.



# File 'lib/replicate/model.rb', line 3

# @return the model path this instance was constructed with (the value
#   interpolated into the "models/<path>" lookup request)
def path; @path; end

#version_id ⇒ Object (readonly)

Returns the value of attribute version_id.



# File 'lib/replicate/model.rb', line 3

# @return the latest version id resolved when the model was looked up; it
#   is sent as the "version" of every prediction this model creates
def version_id; @version_id; end

Instance Method Details

#poll(prediction_id:) ⇒ Object



# File 'lib/replicate/model.rb', lines 18-29

# Polls a prediction until it reaches a terminal state.
#
# Repeatedly GETs "predictions/<prediction_id>" and prints each reported
# status to stdout. While the status is not "succeeded", "failed" or
# "canceled", it sleeps +interval+ seconds between requests.
#
# Implemented as a loop rather than recursion so that long-running
# predictions do not grow the call stack by one frame per poll.
# NOTE(review): any other status (including nil) keeps polling
# indefinitely — consider adding a timeout.
#
# @param prediction_id identifier of the prediction, interpolated into the
#   request path
# @param interval [Numeric] seconds to wait between polls (default 2)
# @return the "output" field of the response body once the status is
#   "succeeded"
# @raise [RuntimeError] when the status is "failed" or "canceled"; the
#   message is the response's "error" field, or a generic message naming the
#   prediction and status when that field is nil
def poll(prediction_id:, interval: 2)
  loop do
    body = @client.requests(method: "GET", path: "predictions/#{prediction_id}").body
    status = body["status"]
    puts(status)
    case status
    when "succeeded"
      return body["output"]
    when "failed", "canceled"
      error = body["error"]
      # `raise nil` would itself fail with a TypeError, hiding the real outcome.
      raise(error.nil? ? "Prediction #{prediction_id} #{status}" : error.to_s)
    end
    sleep(interval)
  end
end

#predict(prompt: "", height: 512, width: 512, image: "") ⇒ Object



# File 'lib/replicate/model.rb', lines 12-16

# Creates a prediction against this model's latest version and blocks until
# it finishes.
#
# POSTs to "predictions" with the stored version id and an input hash built
# from the keyword arguments, then hands the returned id to #poll.
#
# @param prompt [String] value sent as input "prompt"
# @param height [Integer] value sent as input "height"
# @param width [Integer] value sent as input "width"
# @param image [String] value sent as input "image"
# @return whatever #poll returns for the created prediction
def predict(prompt: "", height: 512, width: 512, image: "")
  input = { prompt: prompt, height: height, width: width, image: image }
  created = @client.requests(method: "POST", path: "predictions", version: @version_id, input: input)
  poll(prediction_id: created.body["id"])
end