Class: Langchain::LLM::Cohere

Inherits: Base

Defined in: lib/langchain/llm/cohere.rb

Overview

Wrapper around the Cohere API.

Gem requirements:

gem "cohere-ruby", "~> 0.9.6"

Usage:

llm = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])

Constant Summary

DEFAULTS =
{
  temperature: 0.0,
  completion_model_name: "command",
  chat_completion_model_name: "command-r-plus",
  embeddings_model_name: "small",
  dimensions: 1024,
  truncate: "START"
}.freeze

Instance Attribute Summary

Attributes inherited from Base

#client

Instance Method Summary

Methods inherited from Base

#chat_parameters, #default_dimension, #default_dimensions

Methods included from DependencyHelper

#depends_on

Constructor Details

#initialize(api_key:, default_options: {}) ⇒ Cohere

Returns a new instance of Cohere.



# File 'lib/langchain/llm/cohere.rb', line 23

def initialize(api_key:, default_options: {})
  depends_on "cohere-ruby", req: "cohere"

  @client = ::Cohere::Client.new(api_key: api_key)
  @defaults = DEFAULTS.merge(default_options)
  chat_parameters.update(
    model: {default: @defaults[:chat_completion_model_name]},
    temperature: {default: @defaults[:temperature]}
  )
  chat_parameters.remap(
    system: :preamble,
    messages: :chat_history,
    stop: :stop_sequences,
    top_k: :k,
    top_p: :p
  )
end
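
Any key from DEFAULTS can be overridden per instance via default_options. A minimal sketch (the alternate model name below is an assumption; consult Cohere's documentation for currently available models):

llm = Langchain::LLM::Cohere.new(
  api_key: ENV["COHERE_API_KEY"],
  default_options: {
    chat_completion_model_name: "command-r", # assumed model name
    temperature: 0.7
  }
)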

Instance Method Details

#chat(params = {}) ⇒ Langchain::LLM::CohereResponse

Generate a chat completion for given messages

Parameters:

  • params (Hash) (defaults to: {})

    unified chat parameters from [Langchain::LLM::Parameters::Chat::SCHEMA]

Options Hash (params):

  • :messages (Array<String>)

    Input messages

  • :model (String)

    The model that will complete your prompt

  • :max_tokens (Integer)

    Maximum number of tokens to generate before stopping

  • :stop (Array<String>)

    Custom text sequences that will cause the model to stop generating

  • :stream (Boolean)

    Whether to incrementally stream the response using server-sent events

  • :system (String)

    System prompt

  • :temperature (Float)

    Amount of randomness injected into the response

  • :tools (Array<String>)

    Definitions of tools that the model may use

  • :top_k (Integer)

    Only sample from the top K options for each subsequent token

  • :top_p (Float)

    Use nucleus sampling.

Returns:

  • (Langchain::LLM::CohereResponse)

Raises:

  • (ArgumentError)


# File 'lib/langchain/llm/cohere.rb', line 97

def chat(params = {})
  raise ArgumentError.new("messages argument is required") if Array(params[:messages]).empty?

  parameters = chat_parameters.to_params(params)

  response = client.chat(**parameters)

  Langchain::LLM::CohereResponse.new(response)
end
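
A minimal usage sketch, assuming the role/content message shape used by Langchain's unified chat schema:

# messages is required; an empty array raises ArgumentError.
response = llm.chat(
  messages: [{role: "user", content: "What is the capital of France?"}]
)
response.chat_completion # assumed accessor on CohereResponse for the reply text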

#complete(prompt:, **params) ⇒ Langchain::LLM::CohereResponse

Generate a completion for a given prompt

Parameters:

  • prompt (String)

    The prompt to generate a completion for

  • params (Hash)

    Additional parameters, e.g. :stop_sequences (Array<String>)

Returns:

  • (Langchain::LLM::CohereResponse)



# File 'lib/langchain/llm/cohere.rb', line 63

def complete(prompt:, **params)
  default_params = {
    prompt: prompt,
    temperature: @defaults[:temperature],
    model: @defaults[:completion_model_name],
    truncate: @defaults[:truncate]
  }

  if params[:stop_sequences]
    default_params[:stop_sequences] = params.delete(:stop_sequences)
  end

  default_params.merge!(params)

  default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)

  response = client.generate(**default_params)
  Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
end
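
A minimal usage sketch (the response accessor is an assumption based on the CohereResponse wrapper):

response = llm.complete(prompt: "Once upon a time")
response.completion # assumed accessor for the generated text

# :stop_sequences is extracted and forwarded to the API, as shown above:
llm.complete(prompt: "List three fruits:", stop_sequences: ["\n\n"])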

#embed(text:) ⇒ Langchain::LLM::CohereResponse

Generate an embedding for a given text

Parameters:

  • text (String)

    The text to generate an embedding for

Returns:

  • (Langchain::LLM::CohereResponse)



# File 'lib/langchain/llm/cohere.rb', line 47

def embed(text:)
  response = client.embed(
    texts: [text],
    model: @defaults[:embeddings_model_name]
  )

  Langchain::LLM::CohereResponse.new response, model: @defaults[:embeddings_model_name]
end
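
A minimal usage sketch (the embedding accessor is an assumption; the vector length of 1024 matches the dimensions default above):

response = llm.embed(text: "Hello, world!")
response.embedding # assumed accessor; an Array of 1024 Floats for the "small" model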

#summarize(text:) ⇒ String

Generate a summary in English for a given text

More parameters are available for extending this method: github.com/andreibondarev/cohere-ruby/blob/0.9.4/lib/cohere/client.rb#L107-L115

Parameters:

  • text (String)

    The text to generate a summary for

Returns:

  • (String)

    The summary



# File 'lib/langchain/llm/cohere.rb', line 113

def summarize(text:)
  response = client.summarize(text: text)
  response.dig("summary")
end
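
A minimal usage sketch. Note that, unlike the other methods, this returns the summary String directly rather than a Langchain::LLM::CohereResponse:

summary = llm.summarize(text: "Ruby is a dynamic, open source programming language...")
summary # => a String containing the English summary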