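Little Ashish is doing internships at multiple places. Instead of throwing parties for his friends, he decided to donate candies to children. He likes solving puzzles and playing games, so he plays a small game. Suppose there are N children. The rules of the game are: the $i$-th child gets $i^2$ candies, and a child can get candies only after all the children before him have received theirs. Given $X$ candies, how many children can Little Ashish serve?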
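A direct simulation of these rules is a handy reference for checking the faster methods below. This is a minimal sketch (the name `answer_naive` is ours). It serves children in order until the next child cannot get his $i^2$ candies, so it needs roughly $\sqrt[3]{3X}$ iterations per query.

```python
def answer_naive(X):
    """Serve children 1, 2, 3, ... in order while candies last."""
    served = 0
    remaining = X
    i = 1
    while i * i <= remaining:  # can the i-th child still get i^2 candies?
        remaining -= i * i
        served += 1
        i += 1
    return served


print(answer_naive(5))   # → 2  (1 + 4 = 5)
print(answer_naive(13))  # → 2  (1 + 4 = 5, and 5 + 9 = 14 > 13)
```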
The binary search approach
Consider the list of integers $[S_1, S_2, S_3, \ldots]$, where $S_i$ is the total number of candies donated by Little Ashish up to and including the $i$-th child. This list is strictly increasing.
Now, we want to find the maximum number of children $n$ that we can serve; in other words, the largest integer $n$ such that $S_n \le X$.
Whenever a list is sorted (in either increasing or decreasing order), we can apply binary search to find the elements that are equal to, greater than, or less than a given value.
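If we were willing to store the list explicitly, Python's standard `bisect` module already performs this search: `bisect.bisect_right(prefix, X)` returns the number of entries of the sorted list that are less than or equal to `X`, which is exactly $n$. The sketch below assumes a hypothetical bound `LIMIT` on the number of children; it is only correct when $S_{\text{LIMIT}} > X$.

```python
import bisect
from itertools import accumulate

LIMIT = 10**5  # hypothetical bound: we assume S_LIMIT > X for every query

# prefix[k] holds S_{k+1} = 1^2 + 2^2 + ... + (k+1)^2; the list is sorted
prefix = list(accumulate(i * i for i in range(1, LIMIT + 1)))


def answer_bisect(X):
    # bisect_right counts the entries <= X, i.e. the largest n with S_n <= X
    return bisect.bisect_right(prefix, X)


print(answer_bisect(13))  # → 2
print(answer_bisect(14))  # → 3
```

The price is memory and a fixed `LIMIT`; computing $S_i$ on demand instead of storing the list avoids both.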
The only thing that remains is figuring out the details. The following code snippet illustrates the approach (it is also worth checking the setter's code):
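def answer(X):
    # Invariant: S(L) <= X < S(R)
    L = 0        # S(0) = 0 <= X, so the invariant holds from the start
    R = 1000000  # a really large number which should be
                 # much greater than the actual answer, i.e.
                 # S(R) > X
    while R - L > 1:
        M = (L + R) // 2  # integer division keeps M an integer
        if S(M) <= X:
            L = M
        else:
            R = M
    return L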
Note that this function assumes you have already implemented a function `S(i)` that returns $S_i$ quickly; a closed-form formula for $S_i$ is derived in the addendum below.
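Putting the pieces together: the sketch below pairs this loop with the closed form $S(i) = i(i+1)(2i+1)/6$ derived in the addendum below. Instead of the fixed `R = 1000000`, it finds an upper bound by exponential search (doubling `R` until $S(R) > X$). That is a standard technique swapped in here, not part of the original approach, so that no bound has to be guessed. The names `S` and `answer` follow the snippet above.

```python
def S(i):
    """1^2 + ... + i^2 via the closed form; exact with Python integers."""
    return i * (i + 1) * (2 * i + 1) // 6


def answer(X):
    """Largest n with S(n) <= X, assuming X >= 0."""
    # Exponential search: double R until S(R) > X, keeping S(L) <= X.
    L, R = 0, 1
    while S(R) <= X:
        L, R = R, 2 * R
    # Invariant from here on: S(L) <= X < S(R)
    while R - L > 1:
        M = (L + R) // 2
        if S(M) <= X:
            L = M
        else:
            R = M
    return L


print(answer(13))  # → 2
```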
Addendum: Formula for $S_i$
Let us try to find a formula for $S_i$. Since the $i$-th child requires $i^2$ candies, we have:

$$S_i = 1^2 + 2^2 + 3^2 + \cdots + i^2$$

These are well known as the square pyramidal numbers, and one can derive a formula for $S_i$ using the following manipulations:
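$$
\begin{aligned}
(i+1)^3 &= \sum_{j=0}^{i} \left[(j+1)^3 - j^3\right] && \text{(telescoping series)} \\
&= \sum_{j=0}^{i} \left[j^3 + 3j^2 + 3j + 1 - j^3\right] \\
&= \sum_{j=0}^{i} \left[3j^2 + 3j + 1\right] \\
&= 3\left[\sum_{j=0}^{i} j^2\right] + 3\left[\sum_{j=0}^{i} j\right] + \left[\sum_{j=0}^{i} 1\right] \\
&= 3S_i + 3\left[\sum_{j=0}^{i} j\right] + (i+1) \\
3S_i &= (i+1)^3 - 3\left[\sum_{j=0}^{i} j\right] - (i+1)
\end{aligned}
$$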
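Each step of this manipulation can be sanity-checked numerically. The sketch below (the helper name is ours) evaluates both sides with plain loops for small $i$; it only confirms the identities on the tested range and is not a proof.

```python
def check_telescoping(i):
    squares = sum(j * j for j in range(i + 1))  # S_i (the j = 0 term adds nothing)
    linear = sum(j for j in range(i + 1))       # sum of j for j = 0..i
    telescoped = sum((j + 1) ** 3 - j ** 3 for j in range(i + 1))
    expanded = sum(3 * j * j + 3 * j + 1 for j in range(i + 1))
    return (
        telescoped == (i + 1) ** 3
        and expanded == (i + 1) ** 3
        and 3 * squares == (i + 1) ** 3 - 3 * linear - (i + 1)
    )


print(all(check_telescoping(i) for i in range(200)))  # → True
```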
Now, the number $T_i = \sum_{j=0}^{i} j = 0 + 1 + 2 + \cdots + i$ is called a triangular number, and its formula can be derived as follows:
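$$
\begin{aligned}
T_i &= 1 + \cdots + i \\
T_i &= i + \cdots + 1 \\
2T_i &= (i+1) + \cdots + (i+1) \quad (i \text{ terms}) \\
2T_i &= i(i+1) \\
T_i &= \frac{i(i+1)}{2}
\end{aligned}
$$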
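The pairing argument can be mirrored directly in code: adding $1, \ldots, i$ to its reverse term by term gives $i$ copies of $i+1$. A short sketch (the helper name is ours):

```python
def triangular_by_pairing(i):
    forward = list(range(1, i + 1))        # 1, 2, ..., i
    backward = list(reversed(forward))     # i, ..., 2, 1
    pairs = [a + b for a, b in zip(forward, backward)]
    assert all(p == i + 1 for p in pairs)  # every pair sums to i + 1
    return sum(pairs) // 2                 # 2 * T_i = i * (i + 1)


for i in range(50):
    assert triangular_by_pairing(i) == i * (i + 1) // 2 == sum(range(i + 1))
```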
Back to $S_i$:
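$$
\begin{aligned}
3S_i &= (i+1)^3 - 3\left[\sum_{j=0}^{i} j\right] - (i+1) \\
&= (i+1)^3 - 3T_i - (i+1) \\
&= (i+1)^3 - 3\,\frac{i(i+1)}{2} - (i+1) \\
&= (i+1)\left[(i+1)^2 - \frac{3i}{2} - 1\right] \\
&= (i+1)\left[i^2 + 2i + 1 - \frac{3i}{2} - 1\right] \\
&= (i+1)\left[i^2 + \frac{i}{2}\right] \\
&= \frac{i(i+1)}{2}\left[2i + 1\right] \\
S_i &= \frac{i(i+1)(2i+1)}{6}
\end{aligned}
$$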
Therefore, $S_i$ is just $\dfrac{i(i+1)(2i+1)}{6}$.
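In code, the closed form can be evaluated with exact integer arithmetic: $i(i+1)$ is even and one of $i$, $i+1$, $2i+1$ is a multiple of 3, so $i(i+1)(2i+1)$ is always divisible by 6. A quick sketch comparing it with the running sum of squares:

```python
def S(i):
    # i(i+1)(2i+1) is always a multiple of 6, so // is exact here
    return i * (i + 1) * (2 * i + 1) // 6


running = 0
for i in range(1, 1001):
    running += i * i  # 1^2 + 2^2 + ... + i^2
    assert S(i) == running

print(S(1), S(2), S(3), S(10))  # → 1 5 14 385
```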
Addendum: Newton's method
We'll describe a faster, more advanced approach which we hope you'll find helpful.
Note that we are looking for the largest $n$ such that $S(n) \le X$. Since $n$ is the largest such integer, we also have $S(n+1) > X$. Because $S(x)$ is continuous and increasing for $x \ge 0$, the largest real solution $r$ of the equation $S(x) = X$ lies in the interval $[n, n+1)$, and since $n \le r < n+1$, we have $n = \lfloor r \rfloor$. Therefore, one approach is simply to find the largest real root $r$ of the equation $\frac{r(r+1)(2r+1)}{6} = X$ and take the answer as $n = \lfloor r \rfloor$ :) You can use your favorite root-finding method for this. In this section, we'll use Newton's method.
We want to find a solution $x$ of the equation $S(x) = X$. Since $S(x) = \frac{x(x+1)(2x+1)}{6}$, we can rewrite this equation as $x(x+1)(2x+1) - 6X = 0$, so we are now finding a root of the polynomial $f(x) = x(x+1)(2x+1) - 6X$. Newton's method is an iterative method that starts with an initial guess $x_0$ and updates it using the following rule:
$$x_{i+1} = x_i - \frac{f(x_i)}{f'(x_i)}$$
where $f'$ is the derivative of $f$. One repeats this iteration until $x_i$ converges.
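The update rule itself does not depend on this particular $f$. As an illustration, here is a generic sketch (the function name, the tolerance `tol`, and the iteration cap `max_iter` are our own choices), applied to $f(x) = x^2 - 2$, whose positive root is $\sqrt{2}$:

```python
def newton(f, fprime, x0, tol=1e-12, max_iter=100):
    """Iterate x_{i+1} = x_i - f(x_i) / f'(x_i) until successive iterates agree."""
    x = x0
    for _ in range(max_iter):
        x_next = x - f(x) / fprime(x)
        if abs(x_next - x) < tol:  # converged
            return x_next
        x = x_next
    raise RuntimeError("Newton's method did not converge")


root = newton(lambda x: x * x - 2, lambda x: 2 * x, x0=1.0)
print(root)  # close to the square root of 2
```

In general the method can fail to converge, or divide by zero when $f'(x_i) = 0$; for the $f$ in this section, $f'(x) = 6x(x+1) + 1$ is positive for all $x \ge 0$.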
Now, the derivative of $f(x) = 2x^3 + 3x^2 + x - 6X$ is $f'(x) = 6x^2 + 6x + 1 = 6x(x+1) + 1$. Therefore, our update rule is the following:
$$x_{i+1} = x_i - \frac{x_i(x_i+1)(2x_i+1) - 6X}{6x_i(x_i+1) + 1}$$
Finally, for large $x$ we have $f(x) \approx 2x^3 - 6X$, so a good initial guess is $x_0 = \sqrt[3]{3X}$.
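How good is this guess? Writing $c = \sqrt[3]{3X}$ and expanding $S(x) = \frac{x^3}{3} + \frac{x^2}{2} + \frac{x}{6}$, one gets $S(c) = X + \frac{c^2}{2} + \frac{c}{6} > X$ and $S(c - \tfrac12) = X - \frac{c}{12} < X$. So for $X \ge 1$ the real root satisfies $c - \tfrac12 < r < c$, and combined with $n \le r < n + 1$ this gives $n < x_0 < n + \tfrac32$. The sketch below checks that bound numerically for $X = 1, \ldots, 10^5$; it is a sanity check of our own, not a proof.

```python
def S(i):
    return i * (i + 1) * (2 * i + 1) // 6


n = 0  # largest n with S(n) <= X, maintained incrementally as X grows
for X in range(1, 100_001):
    while S(n + 1) <= X:
        n += 1
    x0 = (3 * X) ** (1 / 3)  # initial guess from f(x) ≈ 2x^3 - 6X
    # x0 overshoots the real root r by less than 1/2, and n <= r < n + 1
    assert n < x0 < n + 1.5, (X, n, x0)

print("x0 was strictly between n and n + 1.5 for every X tested")
```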
All of this is implemented in the following code snippet:
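def answer(X):
    r = pow(3*X, 1/3.)  # initial guess x0 = cube root of 3X
    while True:
        # Newton update: x_{i+1} = x_i - f(x_i) / f'(x_i)
        nr = r - (r*(r+1)*(2*r+1) - 6*X)/(6*r*(r+1) + 1)
        if abs(r - nr) < 1e-6:  # converged
            n = int(nr)
            # nr may sit just below an integer root; fix n up exactly
            while n*(n+1)*(2*n+1) > 6*X:
                n -= 1
            while (n+1)*(n+2)*(2*n+3) <= 6*X:
                n += 1
            return n
        r = nr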
The following C++ solution implements the binary search approach:

#include <algorithm>
#include <iostream>
using namespace std;

typedef long long int lli;

// S(n) = 1^2 + 2^2 + ... + n^2 = n(n+1)(2n+1)/6.
// For n < 10^6 the product stays below about 2 * 10^18, which fits in a
// signed 64-bit integer.
lli func(lli num)
{
    return (num * (num + 1) * (2 * num + 1)) / 6;
}

int main()
{
    int test_case;
    cin >> test_case;
    while (test_case--)
    {
        lli x;
        cin >> x;
        lli ans = 0;         // largest n seen so far with S(n) <= x
        lli mini = 1;
        lli maxi = 1000000;  // search range is [mini, maxi); assumes answer < 10^6
        while (mini < maxi)
        {
            lli mid = (mini + maxi) / 2;
            if (func(mid) > x)
            {
                maxi = mid;           // x candies cannot serve mid children
            }
            else
            {
                ans = max(ans, mid);  // mid children can be served
                mini = mid + 1;
            }
        }
        cout << ans << endl;
    }
    return 0;
}