feat(LLM.php): add abstract call method to LLM class
feat(llm-LLaMACpp.php): add example usage of LLaMACpp class with a prompt text

feat(LLaMACpp.php): add LLaMACpp class implementation with constructor, defaultParams, call, llmType, toArray, and getIdentifyingParams methods.
kambo-1st committed Apr 22, 2023
1 parent 9728e02 commit 643c0c5
Showing 4 changed files with 237 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -35,7 +35,7 @@ This library is still in development. Use at your own risk!
### Supported features

* Prompt formatting
-* LLMs: OpenAI GPT-3, ChatGPT
+* LLMs: OpenAI GPT-3, ChatGPT, llama.cpp
* Vector stores: Simple stupid vector store
* Texts splitters: CharacterTextSplitter, RecursiveCharacterTextSplitter
* Embeddings: OpenAI
9 changes: 9 additions & 0 deletions examples/llm-LLaMACpp.php
@@ -0,0 +1,9 @@
<?php

require __DIR__.'/../vendor/autoload.php';

use Kambo\Langchain\LLMs\LLaMACpp;

$llm = new LLaMACpp(['model_path' => 'path/to/model.bin']);
$text = "What's the best programing language for web?";
echo $llm($text);
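
The constructor also accepts the generation options documented in src/LLMs/LLaMACpp.php below. A variant of the example with a few of them set (the option values here are illustrative, not recommendations):

$llm = new LLaMACpp([
    'model_path'  => 'path/to/model.bin',
    'temperature' => 0.2,      // lower temperature for more deterministic output
    'max_tokens'  => 128,      // cap the length of the completion
    'stop'        => ["\n\n"], // stop generating at the first blank line
]);
echo $llm('Explain HTTP caching in one paragraph.');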
41 changes: 41 additions & 0 deletions src/LLMs/LLM.php
@@ -0,0 +1,41 @@
<?php

namespace Kambo\Langchain\LLMs;

/**
* LLM class that expects subclasses to implement a simpler call() method.
* The purpose of this class is to expose a simpler interface for working
* with LLMs, rather than expecting the user to implement the full
* generateResult() method.
*/
abstract class LLM extends BaseLLM
{
/**
* @param string $prompt
* @param array $parameters
*
* @return string
*/
abstract public function call(string $prompt, array $parameters = []): string;

/**
* Run the LLM on the given prompts.
* Implemented here in terms of call(), so subclasses only need to
* provide call().
*
* @param array      $prompts
* @param array|null $stop
*
* @return LLMResult
*/
public function generateResult(array $prompts, ?array $stop = null): LLMResult
{
$generations = [];
foreach ($prompts as $prompt) {
$result = $this->call($prompt, ['stop' => $stop]);
$generations[] = new Generation($result);
}

return new LLMResult([$generations], []);
}
}
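
To make the contract concrete, here is a minimal sketch of a subclass. EchoLLM is hypothetical and for illustration only: it implements call() and inherits generateResult(), which wraps each completion in a Generation. Depending on what BaseLLM declares abstract, further methods (such as llmType(), mirroring LLaMACpp below) may also be required.

<?php

namespace Kambo\Langchain\LLMs;

// Hypothetical example subclass; not part of this commit.
final class EchoLLM extends LLM
{
    // The only method LLM itself requires subclasses to implement.
    public function call(string $prompt, array $parameters = []): string
    {
        return 'echo: ' . $prompt;
    }

    // Assumed to be required by BaseLLM, mirroring LLaMACpp below.
    public function llmType(): string
    {
        return 'echo';
    }
}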
186 changes: 186 additions & 0 deletions src/LLMs/LLaMACpp.php
@@ -0,0 +1,186 @@
<?php

namespace Kambo\Langchain\LLMs;

use Kambo\Langchain\Callbacks\CallbackManager;
use Psr\SimpleCache\CacheInterface;
use Kambo\Langchain\Exceptions\InvalidArgumentException;
use Kambo\LLamaCPPLangchainAdapter\LLamaCPPLangchainAdapter;

use function class_exists;
use function array_merge;

/**
* Wrapper around the llama.cpp model.
*/
final class LLaMACpp extends LLM
{
private $modelPath;
private $nCtx = 1024;
private $nParts = -1;
private $seed = 0;
private $f16Kv = false;
private $logitsAll = false;
private $vocabOnly = false;
private $useMlock = false;
private $nThreads = null;
private $nBatch = 8;
private $suffix = null;
private $maxTokens = 256;
private $temperature = 0.8;
private $topP = 0.95;
private $logprobs = null;
private $echo = false;
private $stop = [];
private $repeatPenalty = 1.1;
private $topK = 40;
private $lastNTokensSize = 64;
private LLamaCPPLangchainAdapter $adapter;

/**
* Possible options:
*
* model_path - The path to the Llama model file.
* n_ctx - Token context window.
* n_parts - Number of parts to split the model into. If -1, the number of parts is automatically determined.
* seed - Seed. If -1, a random seed is used.
* f16_kv - Use half-precision for key/value cache.
* logits_all - Return logits for all tokens, not just the last token.
* vocab_only - Only load the vocabulary, no weights.
* use_mlock - Force system to keep model in RAM.
n_threads - Number of threads to use. If null, the number of threads is determined automatically.
* n_batch - Number of tokens to process in parallel. Should be a number between 1 and n_ctx.
suffix - A suffix to append to the generated text. If null, no suffix is appended.
* max_tokens - The maximum number of tokens to generate.
* temperature - The temperature to use for sampling.
* top_p - The top-p value to use for sampling.
logprobs - The number of logprobs to return. If null, no logprobs are returned.
* echo - Whether to echo the prompt.
* stop - A list of strings to stop generation when encountered.
* repeat_penalty - The penalty to apply to repeated tokens.
* top_k - The top-k value to use for sampling.
* last_n_tokens_size - The number of tokens to look back when applying the repeat_penalty.
*
* @param array $options
* @param CallbackManager|null $callbackManager
* @param CacheInterface|null $cache
* @param LLamaCPPLangchainAdapter|null $adapter
*
* @throws InvalidArgumentException
*/
public function __construct(
array $options = [],
?CallbackManager $callbackManager = null,
?CacheInterface $cache = null,
?LLamaCPPLangchainAdapter $adapter = null,
) {
if (!class_exists(LLamaCPPLangchainAdapter::class)) {
throw new InvalidArgumentException(
'Could not find LLamaCPPLangchainAdapter.
Please install the LLamaCPPLangchainAdapter library to use this model.'
);
}

$this->modelPath = $options['model_path'] ?? $this->modelPath;
$this->nCtx = $options['n_ctx'] ?? $this->nCtx;
$this->nParts = $options['n_parts'] ?? $this->nParts;
$this->seed = $options['seed'] ?? $this->seed;
$this->f16Kv = $options['f16_kv'] ?? $this->f16Kv;
$this->logitsAll = $options['logits_all'] ?? $this->logitsAll;
$this->vocabOnly = $options['vocab_only'] ?? $this->vocabOnly;
$this->useMlock = $options['use_mlock'] ?? $this->useMlock;
$this->nThreads = $options['n_threads'] ?? $this->nThreads;
$this->nBatch = $options['n_batch'] ?? $this->nBatch;
$this->suffix = $options['suffix'] ?? $this->suffix;
$this->maxTokens = $options['max_tokens'] ?? $this->maxTokens;
$this->temperature = $options['temperature'] ?? $this->temperature;
$this->topP = $options['top_p'] ?? $this->topP;
$this->logprobs = $options['logprobs'] ?? $this->logprobs;
$this->echo = $options['echo'] ?? $this->echo;
$this->stop = $options['stop'] ?? $this->stop;
$this->repeatPenalty = $options['repeat_penalty'] ?? $this->repeatPenalty;
$this->topK = $options['top_k'] ?? $this->topK;
$this->lastNTokensSize = $options['last_n_tokens_size'] ?? $this->lastNTokensSize;

parent::__construct($options, $callbackManager, $cache);

if ($adapter === null) {
$adapter = LLamaCPPLangchainAdapter::create(
[
'model_path' => $this->modelPath,
'n_ctx' => $this->nCtx,
'n_parts' => $this->nParts,
'seed' => $this->seed,
'f16_kv' => $this->f16Kv,
'logits_all' => $this->logitsAll,
'vocab_only' => $this->vocabOnly,
'use_mlock' => $this->useMlock,
'n_threads' => $this->nThreads,
]
);
}

$this->adapter = $adapter;
}

/**
* Returns the default parameters for the model.
*
* @return array
*/
private function defaultParams(): array
{
return [
'suffix' => $this->suffix,
'max_tokens' => $this->maxTokens,
'temperature' => $this->temperature,
'top_p' => $this->topP,
'logprobs' => $this->logprobs,
'echo' => $this->echo,
'stop_sequences' => $this->stop,
'repeat_penalty' => $this->repeatPenalty,
'top_k' => $this->topK,
];
}

/**
* Call the Llama model and return the output.
*
* @param string $prompt     The prompt to use for generation.
* @param array  $parameters Additional parameters; the 'stop' key may hold
*                           a list of strings to stop generation when encountered.
*
* @return string The generated text.
*/
public function call(string $prompt, array $parameters = []): string
{
$params = $this->defaultParams();
$stop = $parameters['stop'] ?? null;

if ($this->stop && $stop !== null) {
throw new InvalidArgumentException('`stop` found in both the input and default params.');
} elseif ($this->stop) {
$params['stop_sequences'] = $this->stop;
} elseif ($stop !== null) {
$params['stop_sequences'] = $stop;
} else {
$params['stop_sequences'] = [];
}

return $this->adapter->predict($prompt, $params);
}

public function llmType(): string
{
return 'llama.cpp';
}

public function toArray(): array
{
return $this->getIdentifyingParams();
}

public function getIdentifyingParams(): array
{
return array_merge(['model_path' => $this->modelPath], $this->defaultParams());
}
}
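
A short usage sketch of the stop-sequence handling in call() (prompts and paths are illustrative; assumes the adapter library is installed): stop strings may come from the call site or from the constructor, and supplying both throws an InvalidArgumentException.

// Stop sequences supplied per call:
$llm = new LLaMACpp(['model_path' => 'path/to/model.bin']);
$text = $llm->call('List three colours:', ['stop' => ["\n\n"]]);

// Or configured once in the constructor:
$llm = new LLaMACpp([
    'model_path' => 'path/to/model.bin',
    'stop'       => ['###'],
]);
$text = $llm->call('List three colours:'); // uses the configured stop list

// Passing a 'stop' parameter here as well would throw an InvalidArgumentException.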
