Add builders for GenerateContent and CountToken

This commit is contained in:
2024-11-26 20:27:36 +00:00
parent e6de1d1ce7
commit 326b3919d1
11 changed files with 128 additions and 51 deletions

View File

@@ -15,12 +15,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
);
let prompt = "What is the airspeed of an unladen swallow?";
let request = CountTokensRequest {
contents: Content {
role: Some("user".to_string()),
parts: Some(vec![Part::Text(prompt.to_string())]),
},
};
let request = CountTokensRequestBuilder::from_prompt(prompt).build();
let result = gemini.count_tokens(&request, "gemini-pro").await?;
println!("Response: {:?}", result);

View File

@@ -19,7 +19,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let request = GenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
tools: Some(vec![Tools {

View File

@@ -18,7 +18,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let prompt = "Generate 10 ideas of blog posts with a title and decription for each idea.";
let request = GenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config: Some(GenerationConfig {

View File

@@ -20,7 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let request = GenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
safety_settings: Some(vec![SafetySetting {

View File

@@ -20,7 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let request = GenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
system_instruction: Some(Content {

View File

@@ -15,8 +15,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
);
let prompt = "What is the airspeed of an unladen swallow?";
let result = gemini.prompt_text(prompt, None).await?;
println!("Response: {}", result);
let request = GenerateContentRequest::builder().with_prompt(prompt).build();
let response = gemini.generate_content(&request, "gemini-pro").await?;
println!("Response: {:?}", response.candidates[0].get_text().unwrap());
Ok(())
}

View File

@@ -5,14 +5,14 @@ use futures_util::stream::StreamExt;
use reqwest_eventsource::{Event, EventSource};
use tracing::error;
use crate::dialogue::{Message, Role};
use crate::dialogue::Message;
use crate::error::{Error, Result};
use crate::prelude::{
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
TextEmbeddingResponse,
};
use crate::types::{PredictImageRequest, PredictImageResponse};
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
use crate::{prelude::Part, token_provider::TokenProvider};
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
@@ -154,7 +154,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
contents: messages
.iter()
.map(|m| Content {
role: Some(m.role.to_string()),
role: Some(m.role),
parts: Some(vec![Part::Text(m.text.clone())]),
})
.collect(),
@@ -177,6 +177,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
/// from the response.
#[deprecated(note = "Use `generate_content` instead")]
pub async fn prompt_text(
&self,
prompt: &str,
@@ -184,7 +185,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
) -> Result<String> {
let request = GenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config: generation_config.cloned(),

View File

@@ -1,35 +1,6 @@
use std::str::FromStr;
use serde::{Deserialize, Serialize};
use crate::{client::GeminiClient, error::Result, prelude::TokenProvider};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum Role {
User,
Model,
}
impl ToString for Role {
fn to_string(&self) -> String {
match self {
Role::User => "user".to_string(),
Role::Model => "model".to_string(),
}
}
}
impl FromStr for Role {
type Err = ();
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"user" => Ok(Role::User),
"model" => Ok(Role::Model),
_ => Err(()),
}
}
}
use crate::{client::GeminiClient, error::Result, prelude::TokenProvider, types::Role};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {

View File

@@ -1,10 +1,10 @@
use std::collections::HashMap;
use std::{collections::HashMap, str::FromStr};
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct Content {
pub role: Option<String>,
pub role: Option<Role>,
pub parts: Option<Vec<Part>>,
}
@@ -22,6 +22,34 @@ impl Content {
}
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Model,
}
impl ToString for Role {
fn to_string(&self) -> String {
match self {
Role::User => "user".to_string(),
Role::Model => "model".to_string(),
}
}
}
impl FromStr for Role {
type Err = ();
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"user" => Ok(Role::User),
"model" => Ok(Role::Model),
_ => Err(()),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum Part {

View File

@@ -7,6 +7,39 @@ pub struct CountTokensRequest {
pub contents: Content,
}
impl CountTokensRequest {
pub fn builder() -> CountTokensRequestBuilder {
CountTokensRequestBuilder::new()
}
}
pub struct CountTokensRequestBuilder {
contents: Content,
}
impl CountTokensRequestBuilder {
pub fn new() -> Self {
CountTokensRequestBuilder {
contents: Content::default(),
}
}
pub fn from_prompt(prompt: &str) -> Self {
CountTokensRequestBuilder {
contents: Content {
parts: Some(vec![super::Part::Text(prompt.to_string())]),
..Default::default()
},
}
}
pub fn build(self) -> CountTokensRequest {
CountTokensRequest {
contents: self.contents,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CountTokensResponse {

View File

@@ -3,7 +3,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::{Content, Part, VertexApiError};
use super::{Content, Part, Role, VertexApiError};
use crate::error::Result;
#[derive(Clone, Default, Serialize, Deserialize)]
@@ -24,7 +24,7 @@ impl GenerateContentRequest {
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
GenerateContentRequest {
contents: vec![Content {
role: Some("user".to_string()),
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config,
@@ -33,6 +33,54 @@ impl GenerateContentRequest {
safety_settings: None,
}
}
pub fn builder() -> GenerateContentRequestBuilder {
GenerateContentRequestBuilder::new()
}
}
pub struct GenerateContentRequestBuilder {
request: GenerateContentRequest,
}
impl GenerateContentRequestBuilder {
fn new() -> Self {
GenerateContentRequestBuilder {
request: GenerateContentRequest::default(),
}
}
pub fn with_prompt(mut self, prompt: &str) -> Self {
self.request.contents = vec![Content {
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}];
self
}
pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self {
self.request.generation_config = Some(generation_config);
self
}
pub fn with_tools(mut self, tools: Vec<Tools>) -> Self {
self.request.tools = Some(tools);
self
}
pub fn with_safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
self.request.safety_settings = Some(safety_settings);
self
}
pub fn with_system_instruction(mut self, system_instruction: Content) -> Self {
self.request.system_instruction = Some(system_instruction);
self
}
pub fn build(self) -> GenerateContentRequest {
self.request
}
}
#[derive(Clone, Default, Serialize, Deserialize)]