Add builders for GenerateContent and CountToken
This commit is contained in:
@@ -15,12 +15,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let prompt = "What is the airspeed of an unladen swallow?";
|
let prompt = "What is the airspeed of an unladen swallow?";
|
||||||
let request = CountTokensRequest {
|
let request = CountTokensRequestBuilder::from_prompt(prompt).build();
|
||||||
contents: Content {
|
|
||||||
role: Some("user".to_string()),
|
|
||||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let result = gemini.count_tokens(&request, "gemini-pro").await?;
|
let result = gemini.count_tokens(&request, "gemini-pro").await?;
|
||||||
println!("Response: {:?}", result);
|
println!("Response: {:?}", result);
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let request = GenerateContentRequest {
|
let request = GenerateContentRequest {
|
||||||
contents: vec![Content {
|
contents: vec![Content {
|
||||||
role: Some("user".to_string()),
|
role: Some(Role::User),
|
||||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
}],
|
}],
|
||||||
tools: Some(vec![Tools {
|
tools: Some(vec![Tools {
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let prompt = "Generate 10 ideas of blog posts with a title and decription for each idea.";
|
let prompt = "Generate 10 ideas of blog posts with a title and decription for each idea.";
|
||||||
let request = GenerateContentRequest {
|
let request = GenerateContentRequest {
|
||||||
contents: vec![Content {
|
contents: vec![Content {
|
||||||
role: Some("user".to_string()),
|
role: Some(Role::User),
|
||||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
}],
|
}],
|
||||||
generation_config: Some(GenerationConfig {
|
generation_config: Some(GenerationConfig {
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let request = GenerateContentRequest {
|
let request = GenerateContentRequest {
|
||||||
contents: vec![Content {
|
contents: vec![Content {
|
||||||
role: Some("user".to_string()),
|
role: Some(Role::User),
|
||||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
}],
|
}],
|
||||||
safety_settings: Some(vec![SafetySetting {
|
safety_settings: Some(vec![SafetySetting {
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let request = GenerateContentRequest {
|
let request = GenerateContentRequest {
|
||||||
contents: vec![Content {
|
contents: vec![Content {
|
||||||
role: Some("user".to_string()),
|
role: Some(Role::User),
|
||||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
}],
|
}],
|
||||||
system_instruction: Some(Content {
|
system_instruction: Some(Content {
|
||||||
|
|||||||
@@ -15,8 +15,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let prompt = "What is the airspeed of an unladen swallow?";
|
let prompt = "What is the airspeed of an unladen swallow?";
|
||||||
let result = gemini.prompt_text(prompt, None).await?;
|
let request = GenerateContentRequest::builder().with_prompt(prompt).build();
|
||||||
println!("Response: {}", result);
|
let response = gemini.generate_content(&request, "gemini-pro").await?;
|
||||||
|
println!("Response: {:?}", response.candidates[0].get_text().unwrap());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,14 +5,14 @@ use futures_util::stream::StreamExt;
|
|||||||
use reqwest_eventsource::{Event, EventSource};
|
use reqwest_eventsource::{Event, EventSource};
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::dialogue::{Message, Role};
|
use crate::dialogue::Message;
|
||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::prelude::{
|
use crate::prelude::{
|
||||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
||||||
TextEmbeddingResponse,
|
TextEmbeddingResponse,
|
||||||
};
|
};
|
||||||
use crate::types::{PredictImageRequest, PredictImageResponse};
|
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
||||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||||
|
|
||||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||||
@@ -154,7 +154,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
contents: messages
|
contents: messages
|
||||||
.iter()
|
.iter()
|
||||||
.map(|m| Content {
|
.map(|m| Content {
|
||||||
role: Some(m.role.to_string()),
|
role: Some(m.role),
|
||||||
parts: Some(vec![Part::Text(m.text.clone())]),
|
parts: Some(vec![Part::Text(m.text.clone())]),
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
@@ -177,6 +177,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
|
|
||||||
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
||||||
/// from the response.
|
/// from the response.
|
||||||
|
#[deprecated(note = "Use `generate_content` instead")]
|
||||||
pub async fn prompt_text(
|
pub async fn prompt_text(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
@@ -184,7 +185,7 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let request = GenerateContentRequest {
|
let request = GenerateContentRequest {
|
||||||
contents: vec![Content {
|
contents: vec![Content {
|
||||||
role: Some("user".to_string()),
|
role: Some(Role::User),
|
||||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
}],
|
}],
|
||||||
generation_config: generation_config.cloned(),
|
generation_config: generation_config.cloned(),
|
||||||
|
|||||||
@@ -1,35 +1,6 @@
|
|||||||
use std::str::FromStr;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{client::GeminiClient, error::Result, prelude::TokenProvider};
|
use crate::{client::GeminiClient, error::Result, prelude::TokenProvider, types::Role};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
||||||
pub enum Role {
|
|
||||||
User,
|
|
||||||
Model,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ToString for Role {
|
|
||||||
fn to_string(&self) -> String {
|
|
||||||
match self {
|
|
||||||
Role::User => "user".to_string(),
|
|
||||||
Role::Model => "model".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromStr for Role {
|
|
||||||
type Err = ();
|
|
||||||
|
|
||||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
|
||||||
match s {
|
|
||||||
"user" => Ok(Role::User),
|
|
||||||
"model" => Ok(Role::Model),
|
|
||||||
_ => Err(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, str::FromStr};
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
|
||||||
pub struct Content {
|
pub struct Content {
|
||||||
pub role: Option<String>,
|
pub role: Option<Role>,
|
||||||
pub parts: Option<Vec<Part>>,
|
pub parts: Option<Vec<Part>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,6 +22,34 @@ impl Content {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Role {
|
||||||
|
User,
|
||||||
|
Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToString for Role {
|
||||||
|
fn to_string(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Role::User => "user".to_string(),
|
||||||
|
Role::Model => "model".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromStr for Role {
|
||||||
|
type Err = ();
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||||
|
match s {
|
||||||
|
"user" => Ok(Role::User),
|
||||||
|
"model" => Ok(Role::Model),
|
||||||
|
_ => Err(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub enum Part {
|
pub enum Part {
|
||||||
|
|||||||
@@ -7,6 +7,39 @@ pub struct CountTokensRequest {
|
|||||||
pub contents: Content,
|
pub contents: Content,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CountTokensRequest {
|
||||||
|
pub fn builder() -> CountTokensRequestBuilder {
|
||||||
|
CountTokensRequestBuilder::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CountTokensRequestBuilder {
|
||||||
|
contents: Content,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CountTokensRequestBuilder {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
CountTokensRequestBuilder {
|
||||||
|
contents: Content::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_prompt(prompt: &str) -> Self {
|
||||||
|
CountTokensRequestBuilder {
|
||||||
|
contents: Content {
|
||||||
|
parts: Some(vec![super::Part::Text(prompt.to_string())]),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> CountTokensRequest {
|
||||||
|
CountTokensRequest {
|
||||||
|
contents: self.contents,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum CountTokensResponse {
|
pub enum CountTokensResponse {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use std::collections::HashMap;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use super::{Content, Part, VertexApiError};
|
use super::{Content, Part, Role, VertexApiError};
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
|
|
||||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||||
@@ -24,7 +24,7 @@ impl GenerateContentRequest {
|
|||||||
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
|
pub fn from_prompt(prompt: &str, generation_config: Option<GenerationConfig>) -> Self {
|
||||||
GenerateContentRequest {
|
GenerateContentRequest {
|
||||||
contents: vec![Content {
|
contents: vec![Content {
|
||||||
role: Some("user".to_string()),
|
role: Some(Role::User),
|
||||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
}],
|
}],
|
||||||
generation_config,
|
generation_config,
|
||||||
@@ -33,6 +33,54 @@ impl GenerateContentRequest {
|
|||||||
safety_settings: None,
|
safety_settings: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn builder() -> GenerateContentRequestBuilder {
|
||||||
|
GenerateContentRequestBuilder::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct GenerateContentRequestBuilder {
|
||||||
|
request: GenerateContentRequest,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerateContentRequestBuilder {
|
||||||
|
fn new() -> Self {
|
||||||
|
GenerateContentRequestBuilder {
|
||||||
|
request: GenerateContentRequest::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_prompt(mut self, prompt: &str) -> Self {
|
||||||
|
self.request.contents = vec![Content {
|
||||||
|
role: Some(Role::User),
|
||||||
|
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||||
|
}];
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self {
|
||||||
|
self.request.generation_config = Some(generation_config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_tools(mut self, tools: Vec<Tools>) -> Self {
|
||||||
|
self.request.tools = Some(tools);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
|
||||||
|
self.request.safety_settings = Some(safety_settings);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_system_instruction(mut self, system_instruction: Content) -> Self {
|
||||||
|
self.request.system_instruction = Some(system_instruction);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> GenerateContentRequest {
|
||||||
|
self.request
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||||
|
|||||||
Reference in New Issue
Block a user