Add builders for GenerateContent and CountToken
This commit is contained in:
@@ -5,14 +5,14 @@ use futures_util::stream::StreamExt;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use tracing::error;
|
||||
|
||||
use crate::dialogue::{Message, Role};
|
||||
use crate::dialogue::Message;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::{
|
||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
||||
TextEmbeddingResponse,
|
||||
};
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse};
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
@@ -154,7 +154,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
contents: messages
|
||||
.iter()
|
||||
.map(|m| Content {
|
||||
role: Some(m.role.to_string()),
|
||||
role: Some(m.role),
|
||||
parts: Some(vec![Part::Text(m.text.clone())]),
|
||||
})
|
||||
.collect(),
|
||||
@@ -177,6 +177,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
|
||||
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
||||
/// from the response.
|
||||
#[deprecated(note = "Use `generate_content` instead")]
|
||||
pub async fn prompt_text(
|
||||
&self,
|
||||
prompt: &str,
|
||||
@@ -184,7 +185,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
) -> Result<String> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".to_string()),
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
generation_config: generation_config.cloned(),
|
||||
|
||||
@@ -1,35 +1,6 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{client::GeminiClient, error::Result, prelude::TokenProvider};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum Role {
|
||||
User,
|
||||
Model,
|
||||
}
|
||||
|
||||
impl ToString for Role {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Model => "model".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Role {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s {
|
||||
"user" => Ok(Role::User),
|
||||
"model" => Ok(Role::Model),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
use crate::{client::GeminiClient, error::Result, prelude::TokenProvider, types::Role};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, str::FromStr};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
|
||||
pub struct Content {
|
||||
pub role: Option<String>,
|
||||
pub role: Option<Role>,
|
||||
pub parts: Option<Vec<Part>>,
|
||||
}
|
||||
|
||||
@@ -22,6 +22,34 @@ impl Content {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Model,
|
||||
}
|
||||
|
||||
impl ToString for Role {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Model => "model".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Role {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s {
|
||||
"user" => Ok(Role::User),
|
||||
"model" => Ok(Role::Model),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum Part {
|
||||
|
||||
@@ -7,6 +7,39 @@ pub struct CountTokensRequest {
|
||||
pub contents: Content,
|
||||
}
|
||||
|
||||
impl CountTokensRequest {
|
||||
pub fn builder() -> CountTokensRequestBuilder {
|
||||
CountTokensRequestBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CountTokensRequestBuilder {
|
||||
contents: Content,
|
||||
}
|
||||
|
||||
impl CountTokensRequestBuilder {
|
||||
pub fn new() -> Self {
|
||||
CountTokensRequestBuilder {
|
||||
contents: Content::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_prompt(prompt: &str) -> Self {
|
||||
CountTokensRequestBuilder {
|
||||
contents: Content {
|
||||
parts: Some(vec![super::Part::Text(prompt.to_string())]),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build(self) -> CountTokensRequest {
|
||||
CountTokensRequest {
|
||||
contents: self.contents,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CountTokensResponse {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::{Content, Part, VertexApiError};
|
||||
use super::{Content, Part, Role, VertexApiError};
|
||||
use crate::error::Result;
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
@@ -24,7 +24,7 @@ impl GenerateContentRequest {
|
||||
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
|
||||
GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".to_string()),
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
generation_config,
|
||||
@@ -33,6 +33,54 @@ impl GenerateContentRequest {
|
||||
safety_settings: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn builder() -> GenerateContentRequestBuilder {
|
||||
GenerateContentRequestBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GenerateContentRequestBuilder {
|
||||
request: GenerateContentRequest,
|
||||
}
|
||||
|
||||
impl GenerateContentRequestBuilder {
|
||||
fn new() -> Self {
|
||||
GenerateContentRequestBuilder {
|
||||
request: GenerateContentRequest::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_prompt(mut self, prompt: &str) -> Self {
|
||||
self.request.contents = vec![Content {
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}];
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self {
|
||||
self.request.generation_config = Some(generation_config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tools(mut self, tools: Vec<Tools>) -> Self {
|
||||
self.request.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
|
||||
self.request.safety_settings = Some(safety_settings);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_system_instruction(mut self, system_instruction: Content) -> Self {
|
||||
self.request.system_instruction = Some(system_instruction);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> GenerateContentRequest {
|
||||
self.request
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
|
||||
Reference in New Issue
Block a user