Add Options to generate and chat
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::types::common::Options;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
@@ -17,8 +19,16 @@ pub struct Message {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub messages: Vec<Message>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
|
||||
/// Runtime options that control text generation
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<Options>,
|
||||
}
|
||||
|
||||
impl ChatRequest {
|
||||
@@ -46,6 +56,7 @@ impl ChatRequestBuilder {
|
||||
model: model.into(),
|
||||
messages: vec![],
|
||||
stream: None,
|
||||
options: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -55,6 +66,11 @@ impl ChatRequestBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Stores `options` on the chat request under construction, replacing any
/// previously stored value, and hands the builder back for chaining.
pub fn options(self, options: Options) -> Self {
    let mut builder = self;
    builder.chat_request.options = Some(options);
    builder
}
|
||||
|
||||
pub fn build(self) -> ChatRequest {
|
||||
self.chat_request
|
||||
}
|
||||
|
||||
@@ -23,3 +23,103 @@ pub enum ThinkLevel {
|
||||
Medium,
|
||||
Low,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
pub struct Options {
|
||||
/// Random seed used for reproducible outputs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<u64>,
|
||||
|
||||
/// Controls randomness in generation (higher = more random)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// Limits next token selection to the K most likely
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<u32>,
|
||||
|
||||
/// Cumulative probability threshold for nucleus sampling
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// Minimum probability threshold for token selection
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub min_p: Option<f32>,
|
||||
|
||||
/// Stop sequences that will halt generation
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Stop>,
|
||||
|
||||
/// Context length size (number of tokens)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub num_ctx: Option<u32>,
|
||||
|
||||
/// Maximum number of tokens to generate
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub num_predict: Option<u32>,
|
||||
}
|
||||
|
||||
impl Options {
|
||||
pub fn builder() -> OptionsBuilder {
|
||||
OptionsBuilder {
|
||||
options: Options::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OptionsBuilder {
|
||||
options: Options,
|
||||
}
|
||||
|
||||
impl OptionsBuilder {
|
||||
pub fn seed(mut self, seed: u64) -> Self {
|
||||
self.options.seed = Some(seed);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||
self.options.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_k(mut self, top_k: u32) -> Self {
|
||||
self.options.top_k = Some(top_k);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||
self.options.top_p = Some(top_p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn min_p(mut self, min_p: f32) -> Self {
|
||||
self.options.min_p = Some(min_p);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop(mut self, stop: Stop) -> Self {
|
||||
self.options.stop = Some(stop);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn num_ctx(mut self, num_ctx: u32) -> Self {
|
||||
self.options.num_ctx = Some(num_ctx);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn num_predict(mut self, num_predict: u32) -> Self {
|
||||
self.options.num_predict = Some(num_predict);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Options {
|
||||
self.options
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum Stop {
|
||||
Single(String),
|
||||
Multiple(Vec<String>),
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::types::common::Think;
|
||||
use crate::types::common::{Options, Think};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateRequest {
|
||||
@@ -38,6 +38,10 @@ pub struct GenerateRequest {
|
||||
/// (true/false) or a string ("high", "medium", "low") for supported models.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub think: Option<Think>,
|
||||
|
||||
/// Runtime options that control text generation
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<Options>,
|
||||
}
|
||||
|
||||
impl GenerateRequest {
|
||||
@@ -62,6 +66,7 @@ impl GenerateRequestBuilder {
|
||||
images: vec![],
|
||||
format: None,
|
||||
think: None,
|
||||
options: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -101,6 +106,11 @@ impl GenerateRequestBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Stores `options` on the generate request under construction, replacing
/// any previously stored value, and hands the builder back for chaining.
pub fn options(self, options: Options) -> Self {
    let mut builder = self;
    builder.generate_request.options = Some(options);
    builder
}
|
||||
|
||||
pub fn build(self) -> GenerateRequest {
|
||||
self.generate_request
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user