Class: Langchain::LLM::Replicate

Inherits:
  Base
    • Object
Defined in:
lib/langchain/llm/replicate.rb

Overview

Wrapper around Replicate.com LLM provider

Gem requirements:

gem "replicate-ruby", "~> 0.2.2"

Usage:

llm = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])

Constant Summary

DEFAULTS =
{
  # TODO: Figure out how to send the temperature to the API
  temperature: 0.01, # Minimum accepted value
  # TODO: Design the interface to pass and use different models
  completion_model: "replicate/vicuna-13b",
  embedding_model: "creatorrr/all-mpnet-base-v2",
  dimensions: 384
}.freeze

Instance Attribute Summary

Attributes inherited from Base

#client, #defaults

Instance Method Summary

Methods inherited from Base

#chat, #chat_parameters, #default_dimension, #default_dimensions

Methods included from DependencyHelper

#depends_on

Constructor Details

#initialize(api_key:, default_options: {}) ⇒ Replicate

Initialize the Replicate LLM

Parameters:

  • api_key (String)

    The API key to use

  • default_options (Hash) (defaults to: {})

    Options merged over DEFAULTS (e.g. completion_model, embedding_model, temperature)



# File 'lib/langchain/llm/replicate.rb', line 27

def initialize(api_key:, default_options: {})
  depends_on "replicate-ruby", req: "replicate"

  ::Replicate.configure do |config|
    config.api_token = api_key
  end

  @client = ::Replicate.client
  @defaults = DEFAULTS.merge(default_options)
end
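
A minimal construction sketch: default_options is merged over DEFAULTS, so any key shown in the Constant Summary above can be overridden (the values below simply restate the defaults for illustration):

llm = Langchain::LLM::Replicate.new(
  api_key: ENV["REPLICATE_API_KEY"],
  default_options: {
    completion_model: "replicate/vicuna-13b",
    embedding_model: "creatorrr/all-mpnet-base-v2",
    temperature: 0.01
  }
)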

Instance Method Details

#complete(prompt:, **params) ⇒ Langchain::LLM::ReplicateResponse

Generate a completion for a given prompt

Parameters:

  • prompt (String)

    The prompt to generate a completion for

Returns:

  • (Langchain::LLM::ReplicateResponse)

    The completion response

# File 'lib/langchain/llm/replicate.rb', line 61

def complete(prompt:, **params)
  response = completion_model.predict(prompt: prompt)

  # Replicate predictions run asynchronously; poll until the job finishes
  until response.finished?
    response.refetch
    sleep(0.1)
  end

  Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:completion_model])
end
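
A usage sketch (the prompt is illustrative, and the call blocks while polling Replicate until the prediction finishes; reading the output via #completion assumes the standard accessor on the response wrapper):

llm = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
response = llm.complete(prompt: "Name the largest planet in our solar system.")
response.completion # => the generated text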

#embed(text:) ⇒ Langchain::LLM::ReplicateResponse Also known as: generate_embedding

Generate an embedding for a given text

Parameters:

  • text (String)

    The text to generate an embedding for

Returns:

  • (Langchain::LLM::ReplicateResponse)

    The embedding response

# File 'lib/langchain/llm/replicate.rb', line 44

def embed(text:)
  response = embeddings_model.predict(input: text)

  # Replicate predictions run asynchronously; poll until the job finishes
  until response.finished?
    response.refetch
    sleep(0.1)
  end

  Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:embedding_model])
end
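
A usage sketch (note the generate_embedding alias; per DEFAULTS, the default model produces 384-dimensional vectors, and reading them via #embedding assumes the standard accessor on the response wrapper):

response = llm.embed(text: "Ruby is a programmer's best friend")
response.embedding # => Array of Floats, length 384 with the default model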

#summarize(text:) ⇒ Langchain::LLM::ReplicateResponse

Generate a summary for a given text

Parameters:

  • text (String)

    The text to generate a summary for

Returns:

  • (Langchain::LLM::ReplicateResponse)

    The summary response (the method delegates to #complete)



# File 'lib/langchain/llm/replicate.rb', line 78

def summarize(text:)
  prompt_template = Langchain::Prompt.load_from_path(
    file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
  )
  prompt = prompt_template.format(text: text)

  complete(
    prompt: prompt,
    temperature: @defaults[:temperature],
    # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
    max_tokens: 2048
  )
end
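
A usage sketch (the input file is a placeholder; per the implementation above, the summary is produced by #complete against the bundled summarize prompt template):

long_text = File.read("article.txt") # placeholder input
response = llm.summarize(text: long_text)
response.completion # => the summary text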