Add Options to generate and chat
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::types::common::Options;
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum Role {
|
pub enum Role {
|
||||||
@@ -17,8 +19,16 @@ pub struct Message {
|
|||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct ChatRequest {
|
pub struct ChatRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||||
pub messages: Vec<Message>,
|
pub messages: Vec<Message>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stream: Option<bool>,
|
pub stream: Option<bool>,
|
||||||
|
|
||||||
|
/// Runtime options that control text generation
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub options: Option<Options>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatRequest {
|
impl ChatRequest {
|
||||||
@@ -46,6 +56,7 @@ impl ChatRequestBuilder {
|
|||||||
model: model.into(),
|
model: model.into(),
|
||||||
messages: vec![],
|
messages: vec![],
|
||||||
stream: None,
|
stream: None,
|
||||||
|
options: None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -55,6 +66,11 @@ impl ChatRequestBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn options(mut self, options: Options) -> Self {
|
||||||
|
self.chat_request.options = Some(options);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn build(self) -> ChatRequest {
|
pub fn build(self) -> ChatRequest {
|
||||||
self.chat_request
|
self.chat_request
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,3 +23,103 @@ pub enum ThinkLevel {
|
|||||||
Medium,
|
Medium,
|
||||||
Low,
|
Low,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||||
|
pub struct Options {
|
||||||
|
/// Random seed used for reproducible outputs
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
|
||||||
|
/// Controls randomness in generation (higher = more random)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
|
||||||
|
/// Limits next token selection to the K most likely
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<u32>,
|
||||||
|
|
||||||
|
/// Cumulative probability threshold for nucleus sampling
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
|
||||||
|
/// Minimum probability threshold for token selection
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_p: Option<f32>,
|
||||||
|
|
||||||
|
/// Stop sequences that will halt generation
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Stop>,
|
||||||
|
|
||||||
|
/// Context length size (number of tokens)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub num_ctx: Option<u32>,
|
||||||
|
|
||||||
|
/// Maximum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub num_predict: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Options {
|
||||||
|
pub fn builder() -> OptionsBuilder {
|
||||||
|
OptionsBuilder {
|
||||||
|
options: Options::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OptionsBuilder {
|
||||||
|
options: Options,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OptionsBuilder {
|
||||||
|
pub fn seed(mut self, seed: u64) -> Self {
|
||||||
|
self.options.seed = Some(seed);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn temperature(mut self, temperature: f32) -> Self {
|
||||||
|
self.options.temperature = Some(temperature);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn top_k(mut self, top_k: u32) -> Self {
|
||||||
|
self.options.top_k = Some(top_k);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn top_p(mut self, top_p: f32) -> Self {
|
||||||
|
self.options.top_p = Some(top_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn min_p(mut self, min_p: f32) -> Self {
|
||||||
|
self.options.min_p = Some(min_p);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stop(mut self, stop: Stop) -> Self {
|
||||||
|
self.options.stop = Some(stop);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn num_ctx(mut self, num_ctx: u32) -> Self {
|
||||||
|
self.options.num_ctx = Some(num_ctx);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn num_predict(mut self, num_predict: u32) -> Self {
|
||||||
|
self.options.num_predict = Some(num_predict);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> Options {
|
||||||
|
self.options
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum Stop {
|
||||||
|
Single(String),
|
||||||
|
Multiple(Vec<String>),
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::types::common::Think;
|
use crate::types::common::{Options, Think};
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct GenerateRequest {
|
pub struct GenerateRequest {
|
||||||
@@ -38,6 +38,10 @@ pub struct GenerateRequest {
|
|||||||
/// (true/false) or a string ("high", "medium", "low") for supported models.
|
/// (true/false) or a string ("high", "medium", "low") for supported models.
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub think: Option<Think>,
|
pub think: Option<Think>,
|
||||||
|
|
||||||
|
/// Runtime options that control text generation
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub options: Option<Options>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GenerateRequest {
|
impl GenerateRequest {
|
||||||
@@ -62,6 +66,7 @@ impl GenerateRequestBuilder {
|
|||||||
images: vec![],
|
images: vec![],
|
||||||
format: None,
|
format: None,
|
||||||
think: None,
|
think: None,
|
||||||
|
options: None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,6 +106,11 @@ impl GenerateRequestBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn options(mut self, options: Options) -> Self {
|
||||||
|
self.generate_request.options = Some(options);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn build(self) -> GenerateRequest {
|
pub fn build(self) -> GenerateRequest {
|
||||||
self.generate_request
|
self.generate_request
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user